diff --git a/.github/actions/setup-go/action.yml b/.github/actions/setup-go/action.yml index aa6a50f517a..0f1313d6765 100644 --- a/.github/actions/setup-go/action.yml +++ b/.github/actions/setup-go/action.yml @@ -52,7 +52,7 @@ runs: # The key is used to create and later look up the cache. It's made of # four parts: # - The base part is made from the OS name, Go version and a - # job-specified key prefix. Example: `linux-go-1.23.10-unit-test-`. + # job-specified key prefix. Example: `linux-go-1.23.12-unit-test-`. # It ensures that a job running on Linux with Go 1.23 only looks for # caches from the same environment. # - The unique part is the `hashFiles('**/go.sum')`, which calculates a diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 550830d28a6..7c4868da806 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -40,7 +40,7 @@ env: # If you change this please also update GO_VERSION in Makefile (then run # `make lint` to see where else it needs to be updated as well). - GO_VERSION: 1.23.10 + GO_VERSION: 1.23.12 jobs: static-checks: diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 5224d153f33..915869017ce 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -12,7 +12,7 @@ defaults: env: # If you change this please also update GO_VERSION in Makefile (then run # `make lint` to see where else it needs to be updated as well). - GO_VERSION: 1.23.10 + GO_VERSION: 1.23.12 jobs: main: diff --git a/.golangci.yml b/.golangci.yml index 9f1ab49a546..eaf737177a6 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,7 +1,7 @@ run: # If you change this please also update GO_VERSION in Makefile (then run # `make lint` to see where else it needs to be updated as well). - go: "1.23.10" + go: "1.23.12" # Abort after 10 minutes. timeout: 10m diff --git a/Dockerfile b/Dockerfile index 54158638562..5cf4f96c71e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ # If you change this please also update GO_VERSION in Makefile (then run # `make lint` to see where else it needs to be updated as well). -FROM golang:1.23.10-alpine as builder +FROM golang:1.23.12-alpine as builder # Force Go to use the cgo based DNS resolver. This is required to ensure DNS # queries required to connect to linked containers succeed. diff --git a/Makefile b/Makefile index c9a3baca51a..97863b0e6b7 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ ACTIVE_GO_VERSION_MINOR := $(shell echo $(ACTIVE_GO_VERSION) | cut -d. -f2) # GO_VERSION is the Go version used for the release build, docker files, and # GitHub Actions. This is the reference version for the project. All other Go # versions are checked against this version. -GO_VERSION = 1.23.10 +GO_VERSION = 1.23.12 GOBUILD := $(GOCC) build -v GOINSTALL := $(GOCC) install -v diff --git a/channeldb/codec.go b/channeldb/codec.go index 8c39f4d7313..e5bab3d5f76 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" @@ -182,11 +183,6 @@ func WriteElement(w io.Writer, element interface{}) error { return err } - case paymentIndexType: - if err := binary.Write(w, byteOrder, e); err != nil { - return err - } - case lnwire.FundingFlag: if err := binary.Write(w, byteOrder, e); err != nil { return err @@ -415,11 +411,6 @@ func ReadElement(r io.Reader, element interface{}) error { return err } - case *paymentIndexType: - if err := binary.Read(r, byteOrder, e); err != nil { - return err - } - case *lnwire.FundingFlag: if err := binary.Read(r, byteOrder, e); err != nil { return err @@ -466,3 +457,37 @@ func ReadElements(r io.Reader, elements ...interface{}) error { } return nil } + +// deserializeTime deserializes time as unix nanoseconds. +func deserializeTime(r io.Reader) (time.Time, error) { + var scratch [8]byte + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return time.Time{}, err + } + + // Convert to time.Time. Interpret unix nano time zero as a zero + // time.Time value. + unixNano := byteOrder.Uint64(scratch[:]) + if unixNano == 0 { + return time.Time{}, nil + } + + return time.Unix(0, int64(unixNano)), nil +} + +// serializeTime serializes time as unix nanoseconds. +func serializeTime(w io.Writer, t time.Time) error { + var scratch [8]byte + + // Convert to unix nano seconds, but only if time is non-zero. Calling + // UnixNano() on a zero time yields an undefined result. + var unixNano int64 + if !t.IsZero() { + unixNano = t.UnixNano() + } + + byteOrder.PutUint64(scratch[:], uint64(unixNano)) + _, err := w.Write(scratch[:]) + + return err +} diff --git a/channeldb/db.go b/channeldb/db.go index 715b906867a..00b29f65f9f 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -203,11 +203,13 @@ var ( migration: mig.CreateTLB(payAddrIndexBucket), }, { - // Initialize payment index bucket which will be used - // to index payments by sequence number. This index will - // be used to allow more efficient ListPayments queries. - number: 15, - migration: mig.CreateTLB(paymentsIndexBucket), + // This used to be create payment related top-level + // buckets, however this is now done by the payment + // package. + number: 15, + migration: func(tx kvdb.RwTx) error { + return nil + }, }, { // Add our existing payments to the index bucket created @@ -450,7 +452,6 @@ var dbTopLevelBuckets = [][]byte{ invoiceBucket, payAddrIndexBucket, setIDIndexBucket, - paymentsIndexBucket, peersBucket, nodeInfoBucket, metaBucket, diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 161fa4a2ee3..ab8d1426fa6 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -553,7 +553,7 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) ( // Create a paginator which reads from our add index bucket with // the parameters provided by the invoice query. - paginator := newPaginator( + paginator := NewPaginator( invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset, q.NumMaxInvoices, ) @@ -603,7 +603,7 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) ( // Query our paginator using accumulateInvoices to build up a // set of invoices. - if err := paginator.query(accumulateInvoices); err != nil { + if err := paginator.Query(accumulateInvoices); err != nil { return err } diff --git a/channeldb/mp_payment_test.go b/channeldb/mp_payment_test.go deleted file mode 100644 index 455a04de0a8..00000000000 --- a/channeldb/mp_payment_test.go +++ /dev/null @@ -1,603 +0,0 @@ -package channeldb - -import ( - "bytes" - "fmt" - "testing" - - "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnwire" - paymentsdb "github.com/lightningnetwork/lnd/payments/db" - "github.com/lightningnetwork/lnd/routing/route" - "github.com/stretchr/testify/require" -) - -var ( - testHash = [32]byte{ - 0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab, - 0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4, - 0x4f, 0x2f, 0x6f, 0x25, 0x88, 0xa3, 0xef, 0xb9, - 0x6a, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53, - } -) - -// TestLazySessionKeyDeserialize tests that we can read htlc attempt session -// keys that were previously serialized as a private key as raw bytes. -func TestLazySessionKeyDeserialize(t *testing.T) { - var b bytes.Buffer - - // Serialize as a private key. - err := WriteElements(&b, priv) - require.NoError(t, err) - - // Deserialize into [btcec.PrivKeyBytesLen]byte. - attempt := HTLCAttemptInfo{} - err = ReadElements(&b, &attempt.sessionKey) - require.NoError(t, err) - require.Zero(t, b.Len()) - - sessionKey := attempt.SessionKey() - require.Equal(t, priv, sessionKey) -} - -// TestRegistrable checks the method `Registrable` behaves as expected for ALL -// possible payment statuses. -func TestRegistrable(t *testing.T) { - t.Parallel() - - testCases := []struct { - status PaymentStatus - registryErr error - hasSettledHTLC bool - paymentFailed bool - }{ - { - status: StatusInitiated, - registryErr: nil, - }, - { - // Test inflight status with no settled HTLC and no - // failed payment. - status: StatusInFlight, - registryErr: nil, - }, - { - // Test inflight status with settled HTLC but no failed - // payment. - status: StatusInFlight, - registryErr: paymentsdb.ErrPaymentPendingSettled, - hasSettledHTLC: true, - }, - { - // Test inflight status with no settled HTLC but failed - // payment. - status: StatusInFlight, - registryErr: paymentsdb.ErrPaymentPendingFailed, - paymentFailed: true, - }, - { - // Test error state with settled HTLC and failed - // payment. - status: 0, - registryErr: paymentsdb.ErrUnknownPaymentStatus, - hasSettledHTLC: true, - paymentFailed: true, - }, - { - status: StatusSucceeded, - registryErr: paymentsdb.ErrPaymentAlreadySucceeded, - }, - { - status: StatusFailed, - registryErr: paymentsdb.ErrPaymentAlreadyFailed, - }, - { - status: 0, - registryErr: paymentsdb.ErrUnknownPaymentStatus, - }, - } - - for i, tc := range testCases { - i, tc := i, tc - - p := &MPPayment{ - Status: tc.status, - State: &MPPaymentState{ - HasSettledHTLC: tc.hasSettledHTLC, - PaymentFailed: tc.paymentFailed, - }, - } - - name := fmt.Sprintf("test_%d_%s", i, p.Status.String()) - t.Run(name, func(t *testing.T) { - t.Parallel() - - err := p.Registrable() - require.ErrorIs(t, err, tc.registryErr, - "registrable under state %v", tc.status) - }) - } -} - -// TestPaymentSetState checks that the method setState creates the -// MPPaymentState as expected. -func TestPaymentSetState(t *testing.T) { - t.Parallel() - - // Create a test preimage and failure reason. - preimage := lntypes.Preimage{1} - failureReasonError := FailureReasonError - - testCases := []struct { - name string - payment *MPPayment - totalAmt int - - expectedState *MPPaymentState - errExpected error - }{ - { - // Test that when the sentAmt exceeds totalAmount, the - // error is returned. - name: "amount exceeded error", - // SentAmt returns 90, 10 - // TerminalInfo returns non-nil, nil - // InFlightHTLCs returns 0 - payment: &MPPayment{ - HTLCs: []HTLCAttempt{ - makeSettledAttempt(100, 10, preimage), - }, - }, - totalAmt: 1, - errExpected: paymentsdb.ErrSentExceedsTotal, - }, - { - // Test that when the htlc is failed, the fee is not - // used. - name: "fee excluded for failed htlc", - payment: &MPPayment{ - // SentAmt returns 90, 10 - // TerminalInfo returns nil, nil - // InFlightHTLCs returns 1 - HTLCs: []HTLCAttempt{ - makeActiveAttempt(100, 10), - makeFailedAttempt(100, 10), - }, - }, - totalAmt: 1000, - expectedState: &MPPaymentState{ - NumAttemptsInFlight: 1, - RemainingAmt: 1000 - 90, - FeesPaid: 10, - HasSettledHTLC: false, - PaymentFailed: false, - }, - }, - { - // Test when the payment is settled, the state should - // be marked as terminated. - name: "payment settled", - // SentAmt returns 90, 10 - // TerminalInfo returns non-nil, nil - // InFlightHTLCs returns 0 - payment: &MPPayment{ - HTLCs: []HTLCAttempt{ - makeSettledAttempt(100, 10, preimage), - }, - }, - totalAmt: 1000, - expectedState: &MPPaymentState{ - NumAttemptsInFlight: 0, - RemainingAmt: 1000 - 90, - FeesPaid: 10, - HasSettledHTLC: true, - PaymentFailed: false, - }, - }, - { - // Test when the payment is failed, the state should be - // marked as terminated. - name: "payment failed", - // SentAmt returns 0, 0 - // TerminalInfo returns nil, non-nil - // InFlightHTLCs returns 0 - payment: &MPPayment{ - FailureReason: &failureReasonError, - }, - totalAmt: 1000, - expectedState: &MPPaymentState{ - NumAttemptsInFlight: 0, - RemainingAmt: 1000, - FeesPaid: 0, - HasSettledHTLC: false, - PaymentFailed: true, - }, - }, - } - - for _, tc := range testCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - // Attach the payment info. - info := &PaymentCreationInfo{ - Value: lnwire.MilliSatoshi(tc.totalAmt), - } - tc.payment.Info = info - - // Call the method that updates the payment state. - err := tc.payment.setState() - require.ErrorIs(t, err, tc.errExpected) - - require.Equal( - t, tc.expectedState, tc.payment.State, - "state not updated as expected", - ) - }) - } -} - -// TestNeedWaitAttempts checks whether we need to wait for the results of the -// HTLC attempts against ALL possible payment statuses. -func TestNeedWaitAttempts(t *testing.T) { - t.Parallel() - - testCases := []struct { - status PaymentStatus - remainingAmt lnwire.MilliSatoshi - hasSettledHTLC bool - hasFailureReason bool - needWait bool - expectedErr error - }{ - { - // For a newly created payment we don't need to wait - // for results. - status: StatusInitiated, - remainingAmt: 1000, - needWait: false, - expectedErr: nil, - }, - { - // With HTLCs inflight we don't need to wait when the - // remainingAmt is not zero and we have no settled - // HTLCs. - status: StatusInFlight, - remainingAmt: 1000, - needWait: false, - expectedErr: nil, - }, - { - // With HTLCs inflight we need to wait when the - // remainingAmt is not zero but we have settled HTLCs. - status: StatusInFlight, - remainingAmt: 1000, - hasSettledHTLC: true, - needWait: true, - expectedErr: nil, - }, - { - // With HTLCs inflight we need to wait when the - // remainingAmt is not zero and the payment is failed. - status: StatusInFlight, - remainingAmt: 1000, - needWait: true, - hasFailureReason: true, - expectedErr: nil, - }, - - { - // With the payment settled, but the remainingAmt is - // not zero, we have an error state. - status: StatusSucceeded, - remainingAmt: 1000, - needWait: false, - expectedErr: paymentsdb.ErrPaymentInternal, - }, - { - // Payment is in terminal state, no need to wait. - status: StatusFailed, - remainingAmt: 1000, - needWait: false, - expectedErr: nil, - }, - { - // A newly created payment with zero remainingAmt - // indicates an error. - status: StatusInitiated, - remainingAmt: 0, - needWait: false, - expectedErr: paymentsdb.ErrPaymentInternal, - }, - { - // With zero remainingAmt we must wait for the results. - status: StatusInFlight, - remainingAmt: 0, - needWait: true, - expectedErr: nil, - }, - { - // Payment is terminated, no need to wait for results. - status: StatusSucceeded, - remainingAmt: 0, - needWait: false, - expectedErr: nil, - }, - { - // Payment is terminated, no need to wait for results. - status: StatusFailed, - remainingAmt: 0, - needWait: false, - expectedErr: paymentsdb.ErrPaymentInternal, - }, - { - // Payment is in an unknown status, return an error. - status: 0, - remainingAmt: 0, - needWait: false, - expectedErr: paymentsdb.ErrUnknownPaymentStatus, - }, - { - // Payment is in an unknown status, return an error. - status: 0, - remainingAmt: 1000, - needWait: false, - expectedErr: paymentsdb.ErrUnknownPaymentStatus, - }, - } - - for _, tc := range testCases { - tc := tc - - p := &MPPayment{ - Info: &PaymentCreationInfo{ - PaymentIdentifier: [32]byte{1, 2, 3}, - }, - Status: tc.status, - State: &MPPaymentState{ - RemainingAmt: tc.remainingAmt, - HasSettledHTLC: tc.hasSettledHTLC, - PaymentFailed: tc.hasFailureReason, - }, - } - - name := fmt.Sprintf("status=%s|remainingAmt=%v|"+ - "settledHTLC=%v|failureReason=%v", tc.status, - tc.remainingAmt, tc.hasSettledHTLC, tc.hasFailureReason) - - t.Run(name, func(t *testing.T) { - t.Parallel() - - result, err := p.NeedWaitAttempts() - require.ErrorIs(t, err, tc.expectedErr) - require.Equalf(t, tc.needWait, result, "status=%v, "+ - "remainingAmt=%v", tc.status, tc.remainingAmt) - }) - } -} - -// TestAllowMoreAttempts checks whether more attempts can be created against -// ALL possible payment statuses. -func TestAllowMoreAttempts(t *testing.T) { - t.Parallel() - - testCases := []struct { - status PaymentStatus - remainingAmt lnwire.MilliSatoshi - hasSettledHTLC bool - paymentFailed bool - allowMore bool - expectedErr error - }{ - { - // A newly created payment with zero remainingAmt - // indicates an error. - status: StatusInitiated, - remainingAmt: 0, - allowMore: false, - expectedErr: paymentsdb.ErrPaymentInternal, - }, - { - // With zero remainingAmt we don't allow more HTLC - // attempts. - status: StatusInFlight, - remainingAmt: 0, - allowMore: false, - expectedErr: nil, - }, - { - // With zero remainingAmt we don't allow more HTLC - // attempts. - status: StatusSucceeded, - remainingAmt: 0, - allowMore: false, - expectedErr: nil, - }, - { - // With zero remainingAmt we don't allow more HTLC - // attempts. - status: StatusFailed, - remainingAmt: 0, - allowMore: false, - expectedErr: nil, - }, - { - // With zero remainingAmt and settled HTLCs we don't - // allow more HTLC attempts. - status: StatusInFlight, - remainingAmt: 0, - hasSettledHTLC: true, - allowMore: false, - expectedErr: nil, - }, - { - // With zero remainingAmt and failed payment we don't - // allow more HTLC attempts. - status: StatusInFlight, - remainingAmt: 0, - paymentFailed: true, - allowMore: false, - expectedErr: nil, - }, - { - // With zero remainingAmt and both settled HTLCs and - // failed payment, we don't allow more HTLC attempts. - status: StatusInFlight, - remainingAmt: 0, - hasSettledHTLC: true, - paymentFailed: true, - allowMore: false, - expectedErr: nil, - }, - { - // A newly created payment can have more attempts. - status: StatusInitiated, - remainingAmt: 1000, - allowMore: true, - expectedErr: nil, - }, - { - // With HTLCs inflight we can have more attempts when - // the remainingAmt is not zero and we have neither - // failed payment or settled HTLCs. - status: StatusInFlight, - remainingAmt: 1000, - allowMore: true, - expectedErr: nil, - }, - { - // With HTLCs inflight we cannot have more attempts - // though the remainingAmt is not zero but we have - // settled HTLCs. - status: StatusInFlight, - remainingAmt: 1000, - hasSettledHTLC: true, - allowMore: false, - expectedErr: nil, - }, - { - // With HTLCs inflight we cannot have more attempts - // though the remainingAmt is not zero but we have - // failed payment. - status: StatusInFlight, - remainingAmt: 1000, - paymentFailed: true, - allowMore: false, - expectedErr: nil, - }, - { - // With HTLCs inflight we cannot have more attempts - // though the remainingAmt is not zero but we have - // settled HTLCs and failed payment. - status: StatusInFlight, - remainingAmt: 1000, - hasSettledHTLC: true, - paymentFailed: true, - allowMore: false, - expectedErr: nil, - }, - { - // With the payment settled, but the remainingAmt is - // not zero, we have an error state. - status: StatusSucceeded, - remainingAmt: 1000, - hasSettledHTLC: true, - allowMore: false, - expectedErr: paymentsdb.ErrPaymentInternal, - }, - { - // With the payment failed with no inflight HTLCs, we - // don't allow more attempts to be made. - status: StatusFailed, - remainingAmt: 1000, - paymentFailed: true, - allowMore: false, - expectedErr: nil, - }, - { - // With the payment in an unknown state, we don't allow - // more attempts to be made. - status: 0, - remainingAmt: 1000, - allowMore: false, - expectedErr: nil, - }, - } - - for i, tc := range testCases { - tc := tc - - p := &MPPayment{ - Info: &PaymentCreationInfo{ - PaymentIdentifier: [32]byte{1, 2, 3}, - }, - Status: tc.status, - State: &MPPaymentState{ - RemainingAmt: tc.remainingAmt, - HasSettledHTLC: tc.hasSettledHTLC, - PaymentFailed: tc.paymentFailed, - }, - } - - name := fmt.Sprintf("test_%d|status=%s|remainingAmt=%v", i, - tc.status, tc.remainingAmt) - - t.Run(name, func(t *testing.T) { - t.Parallel() - - result, err := p.AllowMoreAttempts() - require.ErrorIs(t, err, tc.expectedErr) - require.Equalf(t, tc.allowMore, result, "status=%v, "+ - "remainingAmt=%v", tc.status, tc.remainingAmt) - }) - } -} - -func makeActiveAttempt(total, fee int) HTLCAttempt { - return HTLCAttempt{ - HTLCAttemptInfo: makeAttemptInfo(total, total-fee), - } -} - -func makeSettledAttempt(total, fee int, - preimage lntypes.Preimage) HTLCAttempt { - - return HTLCAttempt{ - HTLCAttemptInfo: makeAttemptInfo(total, total-fee), - Settle: &HTLCSettleInfo{Preimage: preimage}, - } -} - -func makeFailedAttempt(total, fee int) HTLCAttempt { - return HTLCAttempt{ - HTLCAttemptInfo: makeAttemptInfo(total, total-fee), - Failure: &HTLCFailInfo{ - Reason: HTLCFailInternal, - }, - } -} - -func makeAttemptInfo(total, amtForwarded int) HTLCAttemptInfo { - hop := &route.Hop{AmtToForward: lnwire.MilliSatoshi(amtForwarded)} - return HTLCAttemptInfo{ - Route: route.Route{ - TotalAmount: lnwire.MilliSatoshi(total), - Hops: []*route.Hop{hop}, - }, - } -} - -// TestEmptyRoutesGenerateSphinxPacket tests that the generateSphinxPacket -// function is able to gracefully handle being passed a nil set of hops for the -// route by the caller. -func TestEmptyRoutesGenerateSphinxPacket(t *testing.T) { - t.Parallel() - - sessionKey, _ := btcec.NewPrivateKey() - emptyRoute := &route.Route{} - _, _, err := generateSphinxPacket(emptyRoute, testHash[:], sessionKey) - require.ErrorIs(t, err, route.ErrNoRouteHopsProvided) -} diff --git a/channeldb/paginate.go b/channeldb/paginate.go index 496c236e051..dbb3548ebd4 100644 --- a/channeldb/paginate.go +++ b/channeldb/paginate.go @@ -16,9 +16,9 @@ type paginator struct { totalItems uint64 } -// newPaginator returns a struct which can be used to query an indexed bucket +// NewPaginator returns a struct which can be used to query an indexed bucket // in pages. -func newPaginator(c kvdb.RCursor, reversed bool, +func NewPaginator(c kvdb.RCursor, reversed bool, indexOffset, totalItems uint64) paginator { return paginator{ @@ -105,14 +105,14 @@ func (p paginator) cursorStart() ([]byte, []byte) { return indexKey, indexValue } -// query gets the start point for our index offset and iterates through keys +// Query gets the start point for our index offset and iterates through keys // in our index until we reach the total number of items required for the query // or we run out of cursor values. This function takes a fetchAndAppend function // which is responsible for looking up the entry at that index, adding the entry // to its set of return items (if desired) and return a boolean which indicates // whether the item was added. This is required to allow the paginator to // determine when the response has the maximum number of required items. -func (p paginator) query(fetchAndAppend func(k, v []byte) (bool, error)) error { +func (p paginator) Query(fetchAndAppend func(k, v []byte) (bool, error)) error { indexKey, indexValue := p.cursorStart() var totalItems int diff --git a/channeldb/payments.go b/channeldb/payments.go deleted file mode 100644 index a23f891d074..00000000000 --- a/channeldb/payments.go +++ /dev/null @@ -1,162 +0,0 @@ -package channeldb - -import ( - "fmt" - "time" - - "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnwire" -) - -// FailureReason encodes the reason a payment ultimately failed. -type FailureReason byte - -const ( - // FailureReasonTimeout indicates that the payment did timeout before a - // successful payment attempt was made. - FailureReasonTimeout FailureReason = 0 - - // FailureReasonNoRoute indicates no successful route to the - // destination was found during path finding. - FailureReasonNoRoute FailureReason = 1 - - // FailureReasonError indicates that an unexpected error happened during - // payment. - FailureReasonError FailureReason = 2 - - // FailureReasonPaymentDetails indicates that either the hash is unknown - // or the final cltv delta or amount is incorrect. - FailureReasonPaymentDetails FailureReason = 3 - - // FailureReasonInsufficientBalance indicates that we didn't have enough - // balance to complete the payment. - FailureReasonInsufficientBalance FailureReason = 4 - - // FailureReasonCanceled indicates that the payment was canceled by the - // user. - FailureReasonCanceled FailureReason = 5 - - // TODO(joostjager): Add failure reasons for: - // LocalLiquidityInsufficient, RemoteCapacityInsufficient. -) - -// Error returns a human-readable error string for the FailureReason. -func (r FailureReason) Error() string { - return r.String() -} - -// String returns a human-readable FailureReason. -func (r FailureReason) String() string { - switch r { - case FailureReasonTimeout: - return "timeout" - case FailureReasonNoRoute: - return "no_route" - case FailureReasonError: - return "error" - case FailureReasonPaymentDetails: - return "incorrect_payment_details" - case FailureReasonInsufficientBalance: - return "insufficient_balance" - case FailureReasonCanceled: - return "canceled" - } - - return "unknown" -} - -// PaymentCreationInfo is the information necessary to have ready when -// initiating a payment, moving it into state InFlight. -type PaymentCreationInfo struct { - // PaymentIdentifier is the hash this payment is paying to in case of - // non-AMP payments, and the SetID for AMP payments. - PaymentIdentifier lntypes.Hash - - // Value is the amount we are paying. - Value lnwire.MilliSatoshi - - // CreationTime is the time when this payment was initiated. - CreationTime time.Time - - // PaymentRequest is the full payment request, if any. - PaymentRequest []byte - - // FirstHopCustomRecords are the TLV records that are to be sent to the - // first hop of this payment. These records will be transmitted via the - // wire message only and therefore do not affect the onion payload size. - FirstHopCustomRecords lnwire.CustomRecords -} - -// String returns a human-readable description of the payment creation info. -func (p *PaymentCreationInfo) String() string { - return fmt.Sprintf("payment_id=%v, amount=%v, created_at=%v", - p.PaymentIdentifier, p.Value, p.CreationTime) -} - -// PaymentsQuery represents a query to the payments database starting or ending -// at a certain offset index. The number of retrieved records can be limited. -type PaymentsQuery struct { - // IndexOffset determines the starting point of the payments query and - // is always exclusive. In normal order, the query starts at the next - // higher (available) index compared to IndexOffset. In reversed order, - // the query ends at the next lower (available) index compared to the - // IndexOffset. In the case of a zero index_offset, the query will start - // with the oldest payment when paginating forwards, or will end with - // the most recent payment when paginating backwards. - IndexOffset uint64 - - // MaxPayments is the maximal number of payments returned in the - // payments query. - MaxPayments uint64 - - // Reversed gives a meaning to the IndexOffset. If reversed is set to - // true, the query will fetch payments with indices lower than the - // IndexOffset, otherwise, it will return payments with indices greater - // than the IndexOffset. - Reversed bool - - // If IncludeIncomplete is true, then return payments that have not yet - // fully completed. This means that pending payments, as well as failed - // payments will show up if this field is set to true. - IncludeIncomplete bool - - // CountTotal indicates that all payments currently present in the - // payment index (complete and incomplete) should be counted. - CountTotal bool - - // CreationDateStart, expressed in Unix seconds, if set, filters out - // all payments with a creation date greater than or equal to it. - CreationDateStart int64 - - // CreationDateEnd, expressed in Unix seconds, if set, filters out all - // payments with a creation date less than or equal to it. - CreationDateEnd int64 -} - -// PaymentsResponse contains the result of a query to the payments database. -// It includes the set of payments that match the query and integers which -// represent the index of the first and last item returned in the series of -// payments. These integers allow callers to resume their query in the event -// that the query's response exceeds the max number of returnable events. -type PaymentsResponse struct { - // Payments is the set of payments returned from the database for the - // PaymentsQuery. - Payments []*MPPayment - - // FirstIndexOffset is the index of the first element in the set of - // returned MPPayments. Callers can use this to resume their query - // in the event that the slice has too many events to fit into a single - // response. The offset can be used to continue reverse pagination. - FirstIndexOffset uint64 - - // LastIndexOffset is the index of the last element in the set of - // returned MPPayments. Callers can use this to resume their query - // in the event that the slice has too many events to fit into a single - // response. The offset can be used to continue forward pagination. - LastIndexOffset uint64 - - // TotalCount represents the total number of payments that are currently - // stored in the payment database. This will only be set if the - // CountTotal field in the query was set to true. - TotalCount uint64 -} diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go deleted file mode 100644 index dce993ffe5d..00000000000 --- a/channeldb/payments_test.go +++ /dev/null @@ -1,474 +0,0 @@ -package channeldb - -import ( - "context" - "fmt" - "math" - "reflect" - "testing" - "time" - - "github.com/btcsuite/btcd/btcec/v2" - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/record" - "github.com/lightningnetwork/lnd/routing/route" - "github.com/stretchr/testify/require" -) - -var ( - priv, _ = btcec.NewPrivateKey() - pub = priv.PubKey() - vertex = route.NewVertex(pub) - - testHop1 = &route.Hop{ - PubKeyBytes: vertex, - ChannelID: 12345, - OutgoingTimeLock: 111, - AmtToForward: 555, - CustomRecords: record.CustomSet{ - 65536: []byte{}, - 80001: []byte{}, - }, - MPP: record.NewMPP(32, [32]byte{0x42}), - Metadata: []byte{1, 2, 3}, - } - - testHop2 = &route.Hop{ - PubKeyBytes: vertex, - ChannelID: 12345, - OutgoingTimeLock: 111, - AmtToForward: 555, - LegacyPayload: true, - } - - testHop3 = &route.Hop{ - PubKeyBytes: route.NewVertex(pub), - ChannelID: 12345, - OutgoingTimeLock: 111, - AmtToForward: 555, - CustomRecords: record.CustomSet{ - 65536: []byte{}, - 80001: []byte{}, - }, - AMP: record.NewAMP([32]byte{0x69}, [32]byte{0x42}, 1), - Metadata: []byte{1, 2, 3}, - } - - testRoute = route.Route{ - TotalTimeLock: 123, - TotalAmount: 1234567, - SourcePubKey: vertex, - Hops: []*route.Hop{ - testHop2, - testHop1, - }, - } - - testBlindedRoute = route.Route{ - TotalTimeLock: 150, - TotalAmount: 1000, - SourcePubKey: vertex, - Hops: []*route.Hop{ - { - PubKeyBytes: vertex, - ChannelID: 9876, - OutgoingTimeLock: 120, - AmtToForward: 900, - EncryptedData: []byte{1, 3, 3}, - BlindingPoint: pub, - }, - { - PubKeyBytes: vertex, - EncryptedData: []byte{3, 2, 1}, - }, - { - PubKeyBytes: vertex, - Metadata: []byte{4, 5, 6}, - AmtToForward: 500, - OutgoingTimeLock: 100, - TotalAmtMsat: 500, - }, - }, - } -) - -// assertRouteEquals compares to routes for equality and returns an error if -// they are not equal. -func assertRouteEqual(a, b *route.Route) error { - if !reflect.DeepEqual(a, b) { - return fmt.Errorf("HTLCAttemptInfos don't match: %v vs %v", - spew.Sdump(a), spew.Sdump(b)) - } - - return nil -} - -// TestQueryPayments tests retrieval of payments with forwards and reversed -// queries. -func TestQueryPayments(t *testing.T) { - // Define table driven test for QueryPayments. - // Test payments have sequence indices [1, 3, 4, 5, 6, 7]. - // Note that the payment with index 7 has the same payment hash as 6, - // and is stored in a nested bucket within payment 6 rather than being - // its own entry in the payments bucket. We do this to test retrieval - // of legacy payments. - tests := []struct { - name string - query PaymentsQuery - firstIndex uint64 - lastIndex uint64 - - // expectedSeqNrs contains the set of sequence numbers we expect - // our query to return. - expectedSeqNrs []uint64 - }{ - { - name: "IndexOffset at the end of the payments range", - query: PaymentsQuery{ - IndexOffset: 7, - MaxPayments: 7, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 0, - lastIndex: 0, - expectedSeqNrs: nil, - }, - { - name: "query in forwards order, start at beginning", - query: PaymentsQuery{ - IndexOffset: 0, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 3, - expectedSeqNrs: []uint64{1, 3}, - }, - { - name: "query in forwards order, start at end, overflow", - query: PaymentsQuery{ - IndexOffset: 6, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 7, - lastIndex: 7, - expectedSeqNrs: []uint64{7}, - }, - { - name: "start at offset index outside of payments", - query: PaymentsQuery{ - IndexOffset: 20, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 0, - lastIndex: 0, - expectedSeqNrs: nil, - }, - { - name: "overflow in forwards order", - query: PaymentsQuery{ - IndexOffset: 4, - MaxPayments: math.MaxUint64, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 5, - lastIndex: 7, - expectedSeqNrs: []uint64{5, 6, 7}, - }, - { - name: "start at offset index outside of payments, " + - "reversed order", - query: PaymentsQuery{ - IndexOffset: 9, - MaxPayments: 2, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 6, - lastIndex: 7, - expectedSeqNrs: []uint64{6, 7}, - }, - { - name: "query in reverse order, start at end", - query: PaymentsQuery{ - IndexOffset: 0, - MaxPayments: 2, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 6, - lastIndex: 7, - expectedSeqNrs: []uint64{6, 7}, - }, - { - name: "query in reverse order, starting in middle", - query: PaymentsQuery{ - IndexOffset: 4, - MaxPayments: 2, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 3, - expectedSeqNrs: []uint64{1, 3}, - }, - { - name: "query in reverse order, starting in middle, " + - "with underflow", - query: PaymentsQuery{ - IndexOffset: 4, - MaxPayments: 5, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 3, - expectedSeqNrs: []uint64{1, 3}, - }, - { - name: "all payments in reverse, order maintained", - query: PaymentsQuery{ - IndexOffset: 0, - MaxPayments: 7, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 7, - expectedSeqNrs: []uint64{1, 3, 4, 5, 6, 7}, - }, - { - name: "exclude incomplete payments", - query: PaymentsQuery{ - IndexOffset: 0, - MaxPayments: 7, - Reversed: false, - IncludeIncomplete: false, - }, - firstIndex: 7, - lastIndex: 7, - expectedSeqNrs: []uint64{7}, - }, - { - name: "query payments at index gap", - query: PaymentsQuery{ - IndexOffset: 1, - MaxPayments: 7, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 3, - lastIndex: 7, - expectedSeqNrs: []uint64{3, 4, 5, 6, 7}, - }, - { - name: "query payments reverse before index gap", - query: PaymentsQuery{ - IndexOffset: 3, - MaxPayments: 7, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 1, - expectedSeqNrs: []uint64{1}, - }, - { - name: "query payments reverse on index gap", - query: PaymentsQuery{ - IndexOffset: 2, - MaxPayments: 7, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 1, - expectedSeqNrs: []uint64{1}, - }, - { - name: "query payments forward on index gap", - query: PaymentsQuery{ - IndexOffset: 2, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 3, - lastIndex: 4, - expectedSeqNrs: []uint64{3, 4}, - }, - { - name: "query in forwards order, with start creation " + - "time", - query: PaymentsQuery{ - IndexOffset: 0, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - CreationDateStart: 5, - }, - firstIndex: 5, - lastIndex: 6, - expectedSeqNrs: []uint64{5, 6}, - }, - { - name: "query in forwards order, with start creation " + - "time at end, overflow", - query: PaymentsQuery{ - IndexOffset: 0, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - CreationDateStart: 7, - }, - firstIndex: 7, - lastIndex: 7, - expectedSeqNrs: []uint64{7}, - }, - { - name: "query with start and end creation time", - query: PaymentsQuery{ - IndexOffset: 9, - MaxPayments: math.MaxUint64, - Reversed: true, - IncludeIncomplete: true, - CreationDateStart: 3, - CreationDateEnd: 5, - }, - firstIndex: 3, - lastIndex: 5, - expectedSeqNrs: []uint64{3, 4, 5}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - - db, err := MakeTestDB(t) - require.NoError(t, err) - - // Initialize the payment database. - paymentDB, err := NewKVPaymentsDB(db) - require.NoError(t, err) - - // Make a preliminary query to make sure it's ok to - // query when we have no payments. - resp, err := paymentDB.QueryPayments(ctx, tt.query) - require.NoError(t, err) - require.Len(t, resp.Payments, 0) - - // Populate the database with a set of test payments. - // We create 6 original payments, deleting the payment - // at index 2 so that we cover the case where sequence - // numbers are missing. We also add a duplicate payment - // to the last payment added to test the legacy case - // where we have duplicates in the nested duplicates - // bucket. - nonDuplicatePayments := 6 - - for i := 0; i < nonDuplicatePayments; i++ { - // Generate a test payment. - info, _, preimg, err := genInfo(t) - if err != nil { - t.Fatalf("unable to create test "+ - "payment: %v", err) - } - // Override creation time to allow for testing - // of CreationDateStart and CreationDateEnd. - info.CreationTime = time.Unix(int64(i+1), 0) - - // Create a new payment entry in the database. - err = paymentDB.InitPayment( - info.PaymentIdentifier, info, - ) - require.NoError(t, err) - - // Immediately delete the payment with index 2. - if i == 1 { - pmt, err := paymentDB.FetchPayment( - info.PaymentIdentifier, - ) - require.NoError(t, err) - - deletePayment( - t, db, info.PaymentIdentifier, - pmt.SequenceNum, - ) - } - - // If we are on the last payment entry, add a - // duplicate payment with sequence number equal - // to the parent payment + 1. Note that - // duplicate payments will always be succeeded. - if i == (nonDuplicatePayments - 1) { - pmt, err := paymentDB.FetchPayment( - info.PaymentIdentifier, - ) - require.NoError(t, err) - - appendDuplicatePayment( - t, paymentDB.db, - info.PaymentIdentifier, - pmt.SequenceNum+1, - preimg, - ) - } - } - - // Fetch all payments in the database. - allPayments, err := paymentDB.FetchPayments() - if err != nil { - t.Fatalf("payments could not be fetched from "+ - "database: %v", err) - } - - if len(allPayments) != 6 { - t.Fatalf("Number of payments received does "+ - "not match expected one. Got %v, "+ - "want %v.", len(allPayments), 6) - } - - querySlice, err := paymentDB.QueryPayments( - ctx, tt.query, - ) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if tt.firstIndex != querySlice.FirstIndexOffset || - tt.lastIndex != querySlice.LastIndexOffset { - - t.Errorf("First or last index does not match "+ - "expected index. Want (%d, %d), "+ - "got (%d, %d).", - tt.firstIndex, tt.lastIndex, - querySlice.FirstIndexOffset, - querySlice.LastIndexOffset) - } - - if len(querySlice.Payments) != len(tt.expectedSeqNrs) { - t.Errorf("expected: %v payments, got: %v", - len(tt.expectedSeqNrs), - len(querySlice.Payments)) - } - - for i, seqNr := range tt.expectedSeqNrs { - q := querySlice.Payments[i] - if seqNr != q.SequenceNum { - t.Errorf("sequence numbers do not "+ - "match, got %v, want %v", - q.SequenceNum, seqNr) - } - } - }) - } -} diff --git a/config_builder.go b/config_builder.go index 47eb96b50aa..d7163066019 100644 --- a/config_builder.go +++ b/config_builder.go @@ -925,9 +925,9 @@ type DatabaseInstances struct { // InvoiceDB is the database that stores information about invoices. InvoiceDB invoices.InvoiceDB - // KVPaymentsDB is the database that stores all payment related + // PaymentsDB is the database that stores all payment related // information. - KVPaymentsDB *channeldb.KVPaymentsDB + PaymentsDB paymentsdb.PaymentDB // MacaroonDB is the database that stores macaroon root keys. MacaroonDB kvdb.Backend @@ -1226,7 +1226,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( cfg.KeepFailedPaymentAttempts, ), } - kvPaymentsDB, err := channeldb.NewKVPaymentsDB( + kvPaymentsDB, err := paymentsdb.NewKVPaymentsDB( dbs.ChanStateDB, paymentsDBOptions..., ) @@ -1238,7 +1238,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( return nil, nil, err } - dbs.KVPaymentsDB = kvPaymentsDB + dbs.PaymentsDB = kvPaymentsDB // Wrap the watchtower client DB and make sure we clean up. if cfg.WtClient.Active { diff --git a/dev.Dockerfile b/dev.Dockerfile index 6a3036c7ecd..aa584110b05 100644 --- a/dev.Dockerfile +++ b/dev.Dockerfile @@ -1,6 +1,6 @@ # If you change this please also update GO_VERSION in Makefile (then run # `make lint` to see where else it needs to be updated as well). -FROM golang:1.23.10-alpine AS builder +FROM golang:1.23.12-alpine AS builder LABEL maintainer="Olaoluwa Osuntokun " diff --git a/docker/btcd/Dockerfile b/docker/btcd/Dockerfile index 0af869276ff..22d48a8841b 100644 --- a/docker/btcd/Dockerfile +++ b/docker/btcd/Dockerfile @@ -1,6 +1,6 @@ # If you change this please also update GO_VERSION in Makefile (then run # `make lint` to see where else it needs to be updated as well). -FROM golang:1.23.10-alpine as builder +FROM golang:1.23.12-alpine as builder LABEL maintainer="Olaoluwa Osuntokun " diff --git a/docs/INSTALL.md b/docs/INSTALL.md index 1f8244c611a..b953b6fdd18 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -93,7 +93,7 @@ following build dependencies are required: ### Installing Go -`lnd` is written in Go, with a minimum version of `1.23.10` (or, in case this +`lnd` is written in Go, with a minimum version of `1.23.12` (or, in case this document gets out of date, whatever the Go version in the main `go.mod` file requires). To install, run one of the following commands for your OS: @@ -101,16 +101,16 @@ requires). To install, run one of the following commands for your OS: Linux (x86-64) ``` - wget https://dl.google.com/go/go1.23.10.linux-amd64.tar.gz - sha256sum go1.23.10.linux-amd64.tar.gz | awk -F " " '{ print $1 }' + wget https://dl.google.com/go/go1.23.12.linux-amd64.tar.gz + sha256sum go1.23.12.linux-amd64.tar.gz | awk -F " " '{ print $1 }' ``` The final output of the command above should be - `535f9f81802499f2a7dbfa70abb8fda3793725fcc29460f719815f6e10b5fd60`. If it + `d3847fef834e9db11bf64e3fb34db9c04db14e068eeb064f49af747010454f90`. If it isn't, then the target REPO HAS BEEN MODIFIED, and you shouldn't install this version of Go. If it matches, then proceed to install Go: ``` - sudo rm -rf /usr/local/go && sudo tar -C /usr/local -xzf go1.23.10.linux-amd64.tar.gz + sudo rm -rf /usr/local/go && sudo tar -C /usr/local -xzf go1.23.12.linux-amd64.tar.gz export PATH=$PATH:/usr/local/go/bin ``` @@ -119,16 +119,16 @@ requires). To install, run one of the following commands for your OS: Linux (ARMv6) ``` - wget https://dl.google.com/go/go1.23.10.linux-armv6l.tar.gz - sha256sum go1.23.10.linux-armv6l.tar.gz | awk -F " " '{ print $1 }' + wget https://dl.google.com/go/go1.23.12.linux-armv6l.tar.gz + sha256sum go1.23.12.linux-armv6l.tar.gz | awk -F " " '{ print $1 }' ``` The final output of the command above should be - `b6e00c9a72406d394b9f167e74670e28b72ed559cca8115b21be1cb9d5316cb4`. If it + `9704eba01401a3793f54fac162164b9c5d8cc6f3cab5cee72684bb72294d9f41`. If it isn't, then the target REPO HAS BEEN MODIFIED, and you shouldn't install this version of Go. If it matches, then proceed to install Go: ``` - sudo rm -rf /usr/local/go && tar -C /usr/local -xzf go1.23.10.linux-armv6l.tar.gz + sudo rm -rf /usr/local/go && tar -C /usr/local -xzf go1.23.12.linux-armv6l.tar.gz export PATH=$PATH:/usr/local/go/bin ``` diff --git a/docs/release-notes/release-notes-0.19.3.md b/docs/release-notes/release-notes-0.19.3.md index 5179a89ac3a..a6a9bf85d79 100644 --- a/docs/release-notes/release-notes-0.19.3.md +++ b/docs/release-notes/release-notes-0.19.3.md @@ -42,11 +42,15 @@ situations where the sending amount would violate the channel policy restriction (min,max HTLC). +- [Fixed](https://github.com/lightningnetwork/lnd/pull/10141) a case where we + would not resolve all outstanding payment attempts after the overall payment + lifecycle was canceled due to a timeout. + # New Features ## Functional Enhancements -* The default value for `gossip.msg-rate-bytes` has been +- The default value for `gossip.msg-rate-bytes` has been [increased](https://github.com/lightningnetwork/lnd/pull/10096) from 100KB to 1MB, and `gossip.msg-burst-bytes` has been increased from 200KB to 2MB. @@ -87,6 +91,9 @@ ## Code Health +- [The Golang version used was bumped to `v1.23.12` to fix a potential issue + with the SQL API](https://github.com/lightningnetwork/lnd/pull/10138). + ## Tooling and Documentation # Contributors (Alphabetical Order) diff --git a/docs/release-notes/release-notes-0.20.0.md b/docs/release-notes/release-notes-0.20.0.md index 0c3df271aab..579c5f3feef 100644 --- a/docs/release-notes/release-notes-0.20.0.md +++ b/docs/release-notes/release-notes-0.20.0.md @@ -40,6 +40,13 @@ # New Features +- Added [NoOp HTLCs](https://github.com/lightningnetwork/lnd/pull/9871). This +allows sending HTLCs to the remote party without shifting the balances of the +channel. This is currently only possible to use with custom channels, and only +when the appropriate TLV flag is set. This allows for HTLCs carrying metadata to +reflect their state on the channel commitment without having to send or receive +a certain amount of msats. + ## Functional Enhancements * RPCs `walletrpc.EstimateFee` and `walletrpc.FundPsbt` now diff --git a/go.mod b/go.mod index 06be5914538..c321cf97724 100644 --- a/go.mod +++ b/go.mod @@ -220,6 +220,6 @@ replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-d // If you change this please also update docs/INSTALL.md and GO_VERSION in // Makefile (then run `make lint` to see where else it needs to be updated as // well). -go 1.23.10 +go 1.23.12 retract v0.0.2 diff --git a/graph/db/benchmark_test.go b/graph/db/benchmark_test.go index 08dba16de56..1022197c591 100644 --- a/graph/db/benchmark_test.go +++ b/graph/db/benchmark_test.go @@ -736,6 +736,15 @@ func BenchmarkGraphReadMethods(b *testing.B) { require.NoError(b, err) }, }, + { + name: "ChanUpdatesInHorizon", + fn: func(b testing.TB, store V1Store) { + _, err := store.ChanUpdatesInHorizon( + time.Unix(0, 0), time.Now(), + ) + require.NoError(b, err) + }, + }, } for _, test := range tests { diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 49191d7c2c4..1abe34ba4c3 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -2789,7 +2789,8 @@ func (c *KVStore) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex, } nodeKey1, nodeKey2 = makeZombiePubkeys( - &edgeInfo, e1UpdateTime, e2UpdateTime, + edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes, + e1UpdateTime, e2UpdateTime, ) } @@ -2814,27 +2815,27 @@ func (c *KVStore) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex, // the channel. If the channel were to be marked zombie again, it would be // marked with the correct lagging channel since we received an update from only // one side. -func makeZombiePubkeys(info *models.ChannelEdgeInfo, - e1, e2 *time.Time) ([33]byte, [33]byte) { +func makeZombiePubkeys(node1, node2 [33]byte, e1, e2 *time.Time) ([33]byte, + [33]byte) { switch { // If we don't have either edge policy, we'll return both pubkeys so // that the channel can be resurrected by either party. case e1 == nil && e2 == nil: - return info.NodeKey1Bytes, info.NodeKey2Bytes + return node1, node2 // If we're missing edge1, or if both edges are present but edge1 is // older, we'll return edge1's pubkey and a blank pubkey for edge2. This // means that only an update from edge1 will be able to resurrect the // channel. case e1 == nil || (e2 != nil && e1.Before(*e2)): - return info.NodeKey1Bytes, [33]byte{} + return node1, [33]byte{} // Otherwise, we're missing edge2 or edge2 is the older side, so we // return a blank pubkey for edge1. In this case, only an update from // edge2 can resurect the channel. default: - return [33]byte{}, info.NodeKey2Bytes + return [33]byte{}, node1 } } diff --git a/graph/db/sql_migration.go b/graph/db/sql_migration.go index e1ff256a256..0b06e7cac20 100644 --- a/graph/db/sql_migration.go +++ b/graph/db/sql_migration.go @@ -192,7 +192,7 @@ func migrateNodes(ctx context.Context, kvBackend kvdb.Backend, pub, id, dbNode.ID) } - migratedNode, err := buildNode(ctx, sqlDB, &dbNode) + migratedNode, err := buildNode(ctx, sqlDB, dbNode) if err != nil { return fmt.Errorf("could not build migrated node "+ "from dbNode(db id: %d, node pub: %x): %w", @@ -410,6 +410,8 @@ func migrateChannelsAndPolicies(ctx context.Context, kvBackend kvdb.Backend, } channelCount++ + chunk++ + err = migrateSingleChannel( ctx, sqlDB, channel, policy1, policy2, migChanPolicy, ) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 7a798785506..3276f67d706 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -911,6 +911,7 @@ func (s *SQLStore) ChanUpdatesInHorizon(startTime, edges []ChannelEdge hits int ) + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { rows, err := db.GetChannelsByPolicyLastUpdateRange( ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{ @@ -923,72 +924,61 @@ func (s *SQLStore) ChanUpdatesInHorizon(startTime, return err } + if len(rows) == 0 { + return nil + } + + // We'll pre-allocate the slices and maps here with a best + // effort size in order to avoid unnecessary allocations later + // on. + uncachedRows := make( + []sqlc.GetChannelsByPolicyLastUpdateRangeRow, 0, + len(rows), + ) + edgesToCache = make(map[uint64]ChannelEdge, len(rows)) + edgesSeen = make(map[uint64]struct{}, len(rows)) + edges = make([]ChannelEdge, 0, len(rows)) + + // Separate cached from non-cached channels since we will only + // batch load the data for the ones we haven't cached yet. for _, row := range rows { - // If we've already retrieved the info and policies for - // this edge, then we can skip it as we don't need to do - // so again. chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid) + + // Skip duplicates. if _, ok := edgesSeen[chanIDInt]; ok { continue } + edgesSeen[chanIDInt] = struct{}{} + // Check cache first. if channel, ok := s.chanCache.get(chanIDInt); ok { hits++ - edgesSeen[chanIDInt] = struct{}{} edges = append(edges, channel) - continue } - node1, node2, err := buildNodes( - ctx, db, row.GraphNode, row.GraphNode_2, - ) - if err != nil { - return err - } - - channel, err := getAndBuildEdgeInfo( - ctx, db, s.cfg.ChainHash, row.GraphChannel, - node1.PubKeyBytes, node2.PubKeyBytes, - ) - if err != nil { - return fmt.Errorf("unable to build channel "+ - "info: %w", err) - } - - dbPol1, dbPol2, err := extractChannelPolicies(row) - if err != nil { - return fmt.Errorf("unable to extract channel "+ - "policies: %w", err) - } + // Mark this row as one we need to batch load data for. + uncachedRows = append(uncachedRows, row) + } - p1, p2, err := getAndBuildChanPolicies( - ctx, db, dbPol1, dbPol2, channel.ChannelID, - node1.PubKeyBytes, node2.PubKeyBytes, - ) - if err != nil { - return fmt.Errorf("unable to build channel "+ - "policies: %w", err) - } + // If there are no uncached rows, then we can return early. + if len(uncachedRows) == 0 { + return nil + } - edgesSeen[chanIDInt] = struct{}{} - chanEdge := ChannelEdge{ - Info: channel, - Policy1: p1, - Policy2: p2, - Node1: node1, - Node2: node2, - } - edges = append(edges, chanEdge) - edgesToCache[chanIDInt] = chanEdge + // Batch load data for all uncached channels. + newEdges, err := batchBuildChannelEdges( + ctx, s.cfg, db, uncachedRows, + ) + if err != nil { + return fmt.Errorf("unable to batch build channel "+ + "edges: %w", err) } + edges = append(edges, newEdges...) + return nil - }, func() { - edgesSeen = make(map[uint64]struct{}) - edgesToCache = make(map[uint64]ChannelEdge) - edges = nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, fmt.Errorf("unable to fetch channels: %w", err) } @@ -1625,11 +1615,12 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, } var ( - ctx = context.TODO() - deleted []*models.ChannelEdgeInfo + ctx = context.TODO() + edges []*models.ChannelEdgeInfo ) err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { - chanIDsToDelete := make([]int64, 0, len(chanIDs)) + // First, collect all channel rows. + var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow chanCallBack := func(ctx context.Context, row sqlc.GetChannelsBySCIDWithPoliciesRow) error { @@ -1638,64 +1629,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, scid := byteOrder.Uint64(row.GraphChannel.Scid) delete(chanLookup, scid) - node1, node2, err := buildNodeVertices( - row.GraphNode.PubKey, row.GraphNode_2.PubKey, - ) - if err != nil { - return err - } - - info, err := getAndBuildEdgeInfo( - ctx, db, s.cfg.ChainHash, row.GraphChannel, - node1, node2, - ) - if err != nil { - return err - } - - deleted = append(deleted, info) - chanIDsToDelete = append( - chanIDsToDelete, row.GraphChannel.ID, - ) - - if !markZombie { - return nil - } - - nodeKey1, nodeKey2 := info.NodeKey1Bytes, - info.NodeKey2Bytes - if strictZombiePruning { - var e1UpdateTime, e2UpdateTime *time.Time - if row.Policy1LastUpdate.Valid { - e1Time := time.Unix( - row.Policy1LastUpdate.Int64, 0, - ) - e1UpdateTime = &e1Time - } - if row.Policy2LastUpdate.Valid { - e2Time := time.Unix( - row.Policy2LastUpdate.Int64, 0, - ) - e2UpdateTime = &e2Time - } - - nodeKey1, nodeKey2 = makeZombiePubkeys( - info, e1UpdateTime, e2UpdateTime, - ) - } - - err = db.UpsertZombieChannel( - ctx, sqlc.UpsertZombieChannelParams{ - Version: int16(ProtocolV1), - Scid: channelIDToBytes(scid), - NodeKey1: nodeKey1[:], - NodeKey2: nodeKey2[:], - }, - ) - if err != nil { - return fmt.Errorf("unable to mark channel as "+ - "zombie: %w", err) - } + channelRows = append(channelRows, row) return nil } @@ -1711,9 +1645,37 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, return ErrEdgeNotFound } + if len(channelRows) == 0 { + return nil + } + + // Batch build all channel edges. + var chanIDsToDelete []int64 + edges, chanIDsToDelete, err = batchBuildChannelInfo( + ctx, s.cfg, db, channelRows, + ) + if err != nil { + return err + } + + if markZombie { + for i, row := range channelRows { + scid := byteOrder.Uint64(row.GraphChannel.Scid) + + err := handleZombieMarking( + ctx, db, row, edges[i], + strictZombiePruning, scid, + ) + if err != nil { + return fmt.Errorf("unable to mark "+ + "channel as zombie: %w", err) + } + } + } + return s.deleteChannels(ctx, db, chanIDsToDelete) }, func() { - deleted = nil + edges = nil // Re-fill the lookup map. for _, chanID := range chanIDs { @@ -1730,7 +1692,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, s.chanCache.remove(chanID) } - return deleted, nil + return edges, nil } // FetchChannelEdgesByID attempts to lookup the two directed edges for the @@ -2093,55 +2055,40 @@ func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { edges = make(map[uint64]ChannelEdge) ) err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + // First, collect all channel rows. + var channelRows []sqlc.GetChannelsBySCIDWithPoliciesRow chanCallBack := func(ctx context.Context, row sqlc.GetChannelsBySCIDWithPoliciesRow) error { - node1, node2, err := buildNodes( - ctx, db, row.GraphNode, row.GraphNode_2, - ) - if err != nil { - return fmt.Errorf("unable to fetch nodes: %w", - err) - } - - edge, err := getAndBuildEdgeInfo( - ctx, db, s.cfg.ChainHash, row.GraphChannel, - node1.PubKeyBytes, node2.PubKeyBytes, - ) - if err != nil { - return fmt.Errorf("unable to build "+ - "channel info: %w", err) - } - - dbPol1, dbPol2, err := extractChannelPolicies(row) - if err != nil { - return fmt.Errorf("unable to extract channel "+ - "policies: %w", err) - } - - p1, p2, err := getAndBuildChanPolicies( - ctx, db, dbPol1, dbPol2, edge.ChannelID, - node1.PubKeyBytes, node2.PubKeyBytes, - ) - if err != nil { - return fmt.Errorf("unable to build channel "+ - "policies: %w", err) - } + channelRows = append(channelRows, row) + return nil + } - edges[edge.ChannelID] = ChannelEdge{ - Info: edge, - Policy1: p1, - Policy2: p2, - Node1: node1, - Node2: node2, - } + err := s.forEachChanWithPoliciesInSCIDList( + ctx, db, chanCallBack, chanIDs, + ) + if err != nil { + return err + } + if len(channelRows) == 0 { return nil } - return s.forEachChanWithPoliciesInSCIDList( - ctx, db, chanCallBack, chanIDs, + // Batch build all channel edges. + chans, err := batchBuildChannelEdges( + ctx, s.cfg, db, channelRows, ) + if err != nil { + return fmt.Errorf("unable to build channel edges: %w", + err) + } + + for _, c := range chans { + edges[c.Info.ChannelID] = c + } + + return err }, func() { clear(edges) }) @@ -2363,31 +2310,12 @@ func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint, prunedNodes []route.Vertex ) err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { - var chansToDelete []int64 - - // Define the callback function for processing each channel. + // First, collect all channel rows that need to be pruned. + var channelRows []sqlc.GetChannelsByOutpointsRow channelCallback := func(ctx context.Context, row sqlc.GetChannelsByOutpointsRow) error { - node1, node2, err := buildNodeVertices( - row.Node1Pubkey, row.Node2Pubkey, - ) - if err != nil { - return err - } - - info, err := getAndBuildEdgeInfo( - ctx, db, s.cfg.ChainHash, row.GraphChannel, - node1, node2, - ) - if err != nil { - return err - } - - closedChans = append(closedChans, info) - chansToDelete = append( - chansToDelete, row.GraphChannel.ID, - ) + channelRows = append(channelRows, row) return nil } @@ -2400,6 +2328,32 @@ func (s *SQLStore) PruneGraph(spentOutputs []*wire.OutPoint, "outpoints: %w", err) } + if len(channelRows) == 0 { + // There are no channels to prune. So we can exit early + // after updating the prune log. + err = db.UpsertPruneLogEntry( + ctx, sqlc.UpsertPruneLogEntryParams{ + BlockHash: blockHash[:], + BlockHeight: int64(blockHeight), + }, + ) + if err != nil { + return fmt.Errorf("unable to insert prune log "+ + "entry: %w", err) + } + + return nil + } + + // Batch build all channel edges for pruning. + var chansToDelete []int64 + closedChans, chansToDelete, err = batchBuildChannelInfo( + ctx, s.cfg, db, channelRows, + ) + if err != nil { + return err + } + err = s.deleteChannels(ctx, db, chansToDelete) if err != nil { return fmt.Errorf("unable to delete channels: %w", err) @@ -2658,27 +2612,29 @@ func (s *SQLStore) DisconnectBlockAtHeight(height uint32) ( return fmt.Errorf("unable to fetch channels: %w", err) } - chanIDsToDelete := make([]int64, len(rows)) - for i, row := range rows { - node1, node2, err := buildNodeVertices( - row.Node1PubKey, row.Node2PubKey, - ) - if err != nil { - return err - } - - channel, err := getAndBuildEdgeInfo( - ctx, db, s.cfg.ChainHash, row.GraphChannel, - node1, node2, + if len(rows) == 0 { + // No channels to disconnect, but still clean up prune + // log. + return db.DeletePruneLogEntriesInRange( + ctx, sqlc.DeletePruneLogEntriesInRangeParams{ + StartHeight: int64(height), + EndHeight: int64( + endShortChanID.BlockHeight, + ), + }, ) - if err != nil { - return err - } + } - chanIDsToDelete[i] = row.GraphChannel.ID - removedChans = append(removedChans, channel) + // Batch build all channel edges for disconnection. + channelEdges, chanIDsToDelete, err := batchBuildChannelInfo( + ctx, s.cfg, db, rows, + ) + if err != nil { + return err } + removedChans = channelEdges + err = s.deleteChannels(ctx, db, chanIDsToDelete) if err != nil { return fmt.Errorf("unable to delete channels: %w", err) @@ -3230,7 +3186,7 @@ func getNodeByPubKey(ctx context.Context, db SQLQueries, return 0, nil, fmt.Errorf("unable to fetch node: %w", err) } - node, err := buildNode(ctx, db, &dbNode) + node, err := buildNode(ctx, db, dbNode) if err != nil { return 0, nil, fmt.Errorf("unable to build node: %w", err) } @@ -3255,7 +3211,7 @@ func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub, // record. The node's features, addresses and extra signed fields are also // fetched from the database and set on the node. func buildNode(ctx context.Context, db SQLQueries, - dbNode *sqlc.GraphNode) (*models.LightningNode, error) { + dbNode sqlc.GraphNode) (*models.LightningNode, error) { // NOTE: buildNode is only used to load the data for a single node, and // so no paged queries will be performed. This means that it's ok to @@ -3275,7 +3231,7 @@ func buildNode(ctx context.Context, db SQLQueries, // from the provided sqlc.GraphNode and batchNodeData. If the node does have // features/addresses/extra fields, then the corresponding fields are expected // to be present in the batchNodeData. -func buildNodeWithBatchData(dbNode *sqlc.GraphNode, +func buildNodeWithBatchData(dbNode sqlc.GraphNode, batchData *batchNodeData) (*models.LightningNode, error) { if dbNode.Version != int16(ProtocolV1) { @@ -3363,7 +3319,7 @@ func forEachNodeInBatch(ctx context.Context, cfg *sqldb.QueryConfig, } for _, dbNode := range nodes { - node, err := buildNodeWithBatchData(&dbNode, batchData) + node, err := buildNodeWithBatchData(dbNode, batchData) if err != nil { return fmt.Errorf("unable to build node(id=%d): %w", dbNode.ID, err) @@ -4228,25 +4184,6 @@ func buildChanPolicy(dbPolicy sqlc.GraphChannelPolicy, channelID uint64, }, nil } -// buildNodes builds the models.LightningNode instances for the -// given row which is expected to be a sqlc type that contains node information. -func buildNodes(ctx context.Context, db SQLQueries, dbNode1, - dbNode2 sqlc.GraphNode) (*models.LightningNode, *models.LightningNode, - error) { - - node1, err := buildNode(ctx, db, &dbNode1) - if err != nil { - return nil, nil, err - } - - node2, err := buildNode(ctx, db, &dbNode2) - if err != nil { - return nil, nil, err - } - - return node1, node2, nil -} - // extractChannelPolicies extracts the sqlc.GraphChannelPolicy records from the give // row which is expected to be a sqlc type that contains channel policy // information. It returns two policies, which may be nil if the policy @@ -5089,7 +5026,7 @@ func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig, processItem := func(ctx context.Context, dbNode sqlc.GraphNode, batchData *batchNodeData) error { - node, err := buildNodeWithBatchData(&dbNode, batchData) + node, err := buildNodeWithBatchData(dbNode, batchData) if err != nil { return fmt.Errorf("unable to build "+ "node(id=%d): %w", dbNode.ID, err) @@ -5297,3 +5234,208 @@ func buildDirectedChannel(chain chainhash.Hash, nodeID int64, return directedChannel, nil } + +// batchBuildChannelEdges builds a slice of ChannelEdge instances from the +// provided rows. It uses batch loading for channels, policies, and nodes. +func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context, + cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) { + + var ( + channelIDs = make([]int64, len(rows)) + policyIDs = make([]int64, 0, len(rows)*2) + nodeIDs = make([]int64, 0, len(rows)*2) + + // nodeIDSet is used to ensure we only collect unique node IDs. + nodeIDSet = make(map[int64]bool) + + // edges will hold the final channel edges built from the rows. + edges = make([]ChannelEdge, 0, len(rows)) + ) + + // Collect all IDs needed for batch loading. + for i, row := range rows { + channelIDs[i] = row.Channel().ID + + // Collect policy IDs + dbPol1, dbPol2, err := extractChannelPolicies(row) + if err != nil { + return nil, fmt.Errorf("unable to extract channel "+ + "policies: %w", err) + } + if dbPol1 != nil { + policyIDs = append(policyIDs, dbPol1.ID) + } + if dbPol2 != nil { + policyIDs = append(policyIDs, dbPol2.ID) + } + + var ( + node1ID = row.Node1().ID + node2ID = row.Node2().ID + ) + + // Collect unique node IDs. + if !nodeIDSet[node1ID] { + nodeIDs = append(nodeIDs, node1ID) + nodeIDSet[node1ID] = true + } + + if !nodeIDSet[node2ID] { + nodeIDs = append(nodeIDs, node2ID) + nodeIDSet[node2ID] = true + } + } + + // Batch the data for all the channels and policies. + channelBatchData, err := batchLoadChannelData( + ctx, cfg.QueryCfg, db, channelIDs, policyIDs, + ) + if err != nil { + return nil, fmt.Errorf("unable to batch load channel and "+ + "policy data: %w", err) + } + + // Batch the data for all the nodes. + nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs) + if err != nil { + return nil, fmt.Errorf("unable to batch load node data: %w", + err) + } + + // Build all channel edges using batch data. + for _, row := range rows { + // Build nodes using batch data. + node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData) + if err != nil { + return nil, fmt.Errorf("unable to build node1: %w", err) + } + + node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData) + if err != nil { + return nil, fmt.Errorf("unable to build node2: %w", err) + } + + // Build channel info using batch data. + channel, err := buildEdgeInfoWithBatchData( + cfg.ChainHash, row.Channel(), node1.PubKeyBytes, + node2.PubKeyBytes, channelBatchData, + ) + if err != nil { + return nil, fmt.Errorf("unable to build channel "+ + "info: %w", err) + } + + // Extract and build policies using batch data. + dbPol1, dbPol2, err := extractChannelPolicies(row) + if err != nil { + return nil, fmt.Errorf("unable to extract channel "+ + "policies: %w", err) + } + + p1, p2, err := buildChanPoliciesWithBatchData( + dbPol1, dbPol2, channel.ChannelID, + node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData, + ) + if err != nil { + return nil, fmt.Errorf("unable to build channel "+ + "policies: %w", err) + } + + edges = append(edges, ChannelEdge{ + Info: channel, + Policy1: p1, + Policy2: p2, + Node1: node1, + Node2: node2, + }) + } + + return edges, nil +} + +// batchBuildChannelInfo builds a slice of models.ChannelEdgeInfo +// instances from the provided rows using batch loading for channel data. +func batchBuildChannelInfo[T sqlc.ChannelAndNodeIDs](ctx context.Context, + cfg *SQLStoreConfig, db SQLQueries, rows []T) ( + []*models.ChannelEdgeInfo, []int64, error) { + + if len(rows) == 0 { + return nil, nil, nil + } + + // Collect all the channel IDs needed for batch loading. + channelIDs := make([]int64, len(rows)) + for i, row := range rows { + channelIDs[i] = row.Channel().ID + } + + // Batch load the channel data. + channelBatchData, err := batchLoadChannelData( + ctx, cfg.QueryCfg, db, channelIDs, nil, + ) + if err != nil { + return nil, nil, fmt.Errorf("unable to batch load channel "+ + "data: %w", err) + } + + // Build all channel edges using batch data. + edges := make([]*models.ChannelEdgeInfo, 0, len(rows)) + for _, row := range rows { + node1, node2, err := buildNodeVertices( + row.Node1Pub(), row.Node2Pub(), + ) + if err != nil { + return nil, nil, err + } + + // Build channel info using batch data + info, err := buildEdgeInfoWithBatchData( + cfg.ChainHash, row.Channel(), node1, node2, + channelBatchData, + ) + if err != nil { + return nil, nil, err + } + + edges = append(edges, info) + } + + return edges, channelIDs, nil +} + +// handleZombieMarking is a helper function that handles the logic of +// marking a channel as a zombie in the database. It takes into account whether +// we are in strict zombie pruning mode, and adjusts the node public keys +// accordingly based on the last update timestamps of the channel policies. +func handleZombieMarking(ctx context.Context, db SQLQueries, + row sqlc.GetChannelsBySCIDWithPoliciesRow, info *models.ChannelEdgeInfo, + strictZombiePruning bool, scid uint64) error { + + nodeKey1, nodeKey2 := info.NodeKey1Bytes, info.NodeKey2Bytes + + if strictZombiePruning { + var e1UpdateTime, e2UpdateTime *time.Time + if row.Policy1LastUpdate.Valid { + e1Time := time.Unix(row.Policy1LastUpdate.Int64, 0) + e1UpdateTime = &e1Time + } + if row.Policy2LastUpdate.Valid { + e2Time := time.Unix(row.Policy2LastUpdate.Int64, 0) + e2UpdateTime = &e2Time + } + + nodeKey1, nodeKey2 = makeZombiePubkeys( + info.NodeKey1Bytes, info.NodeKey2Bytes, e1UpdateTime, + e2UpdateTime, + ) + } + + return db.UpsertZombieChannel( + ctx, sqlc.UpsertZombieChannelParams{ + Version: int16(ProtocolV1), + Scid: channelIDToBytes(scid), + NodeKey1: nodeKey1[:], + NodeKey2: nodeKey2[:], + }, + ) +} diff --git a/healthcheck/go.mod b/healthcheck/go.mod index f459a921461..ac363500147 100644 --- a/healthcheck/go.mod +++ b/healthcheck/go.mod @@ -24,4 +24,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -go 1.23.10 +go 1.23.12 diff --git a/kvdb/go.mod b/kvdb/go.mod index b711afcb1f7..e37ff83981c 100644 --- a/kvdb/go.mod +++ b/kvdb/go.mod @@ -147,4 +147,4 @@ replace github.com/ulikunitz/xz => github.com/ulikunitz/xz v0.5.11 // https://deps.dev/advisory/OSV/GO-2021-0053?from=%2Fgo%2Fgithub.com%252Fgogo%252Fprotobuf%2Fv1.3.1 replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2 -go 1.23.10 +go 1.23.12 diff --git a/lnrpc/Dockerfile b/lnrpc/Dockerfile index 05f916618b8..be8c9f7c0d6 100644 --- a/lnrpc/Dockerfile +++ b/lnrpc/Dockerfile @@ -1,6 +1,6 @@ # If you change this please also update GO_VERSION in Makefile (then run # `make lint` to see where else it needs to be updated as well). -FROM golang:1.23.10-bookworm +FROM golang:1.23.12-bookworm RUN apt-get update && apt-get install -y \ git \ diff --git a/lnrpc/gen_protos_docker.sh b/lnrpc/gen_protos_docker.sh index e72486359c7..2253bdc0681 100755 --- a/lnrpc/gen_protos_docker.sh +++ b/lnrpc/gen_protos_docker.sh @@ -6,7 +6,7 @@ set -e DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # golang docker image version used in this script. -GO_IMAGE=docker.io/library/golang:1.23.10-alpine +GO_IMAGE=docker.io/library/golang:1.23.12-alpine PROTOBUF_VERSION=$(docker run --rm -v $DIR/../:/lnd -w /lnd $GO_IMAGE \ go list -f '{{.Version}}' -m google.golang.org/protobuf) diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 34cd7047e7b..022f4619acc 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -14,7 +14,6 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/wire" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/feature" "github.com/lightningnetwork/lnd/fn/v2" @@ -22,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing/route" @@ -1488,7 +1488,7 @@ func UnmarshalAMP(reqAMP *lnrpc.AMPRecord) (*record.AMP, error) { // MarshalHTLCAttempt constructs an RPC HTLCAttempt from the db representation. func (r *RouterBackend) MarshalHTLCAttempt( - htlc channeldb.HTLCAttempt) (*lnrpc.HTLCAttempt, error) { + htlc paymentsdb.HTLCAttempt) (*lnrpc.HTLCAttempt, error) { route, err := r.MarshallRoute(&htlc.Route) if err != nil { @@ -1529,7 +1529,7 @@ func (r *RouterBackend) MarshalHTLCAttempt( // marshallHtlcFailure marshalls htlc fail info from the database to its rpc // representation. -func marshallHtlcFailure(failure *channeldb.HTLCFailInfo) (*lnrpc.Failure, +func marshallHtlcFailure(failure *paymentsdb.HTLCFailInfo) (*lnrpc.Failure, error) { rpcFailure := &lnrpc.Failure{ @@ -1537,16 +1537,16 @@ func marshallHtlcFailure(failure *channeldb.HTLCFailInfo) (*lnrpc.Failure, } switch failure.Reason { - case channeldb.HTLCFailUnknown: + case paymentsdb.HTLCFailUnknown: rpcFailure.Code = lnrpc.Failure_UNKNOWN_FAILURE - case channeldb.HTLCFailUnreadable: + case paymentsdb.HTLCFailUnreadable: rpcFailure.Code = lnrpc.Failure_UNREADABLE_FAILURE - case channeldb.HTLCFailInternal: + case paymentsdb.HTLCFailInternal: rpcFailure.Code = lnrpc.Failure_INTERNAL_FAILURE - case channeldb.HTLCFailMessage: + case paymentsdb.HTLCFailMessage: err := marshallWireError(failure.Message, rpcFailure) if err != nil { return nil, err @@ -1743,7 +1743,7 @@ func marshallChannelUpdate(update *lnwire.ChannelUpdate1) *lnrpc.ChannelUpdate { } // MarshallPayment marshall a payment to its rpc representation. -func (r *RouterBackend) MarshallPayment(payment *channeldb.MPPayment) ( +func (r *RouterBackend) MarshallPayment(payment *paymentsdb.MPPayment) ( *lnrpc.Payment, error) { // Fetch the payment's preimage and the total paid in fees. @@ -1813,11 +1813,11 @@ func (r *RouterBackend) MarshallPayment(payment *channeldb.MPPayment) ( // convertPaymentStatus converts a channeldb.PaymentStatus to the type expected // by the RPC. -func convertPaymentStatus(dbStatus channeldb.PaymentStatus, useInit bool) ( +func convertPaymentStatus(dbStatus paymentsdb.PaymentStatus, useInit bool) ( lnrpc.Payment_PaymentStatus, error) { switch dbStatus { - case channeldb.StatusInitiated: + case paymentsdb.StatusInitiated: // If the client understands the new status, return it. if useInit { return lnrpc.Payment_INITIATED, nil @@ -1826,13 +1826,13 @@ func convertPaymentStatus(dbStatus channeldb.PaymentStatus, useInit bool) ( // Otherwise remain the old behavior. return lnrpc.Payment_IN_FLIGHT, nil - case channeldb.StatusInFlight: + case paymentsdb.StatusInFlight: return lnrpc.Payment_IN_FLIGHT, nil - case channeldb.StatusSucceeded: + case paymentsdb.StatusSucceeded: return lnrpc.Payment_SUCCEEDED, nil - case channeldb.StatusFailed: + case paymentsdb.StatusFailed: return lnrpc.Payment_FAILED, nil default: @@ -1842,7 +1842,7 @@ func convertPaymentStatus(dbStatus channeldb.PaymentStatus, useInit bool) ( // marshallPaymentFailureReason marshalls the failure reason to the corresponding rpc // type. -func marshallPaymentFailureReason(reason *channeldb.FailureReason) ( +func marshallPaymentFailureReason(reason *paymentsdb.FailureReason) ( lnrpc.PaymentFailureReason, error) { if reason == nil { @@ -1850,22 +1850,22 @@ func marshallPaymentFailureReason(reason *channeldb.FailureReason) ( } switch *reason { - case channeldb.FailureReasonTimeout: + case paymentsdb.FailureReasonTimeout: return lnrpc.PaymentFailureReason_FAILURE_REASON_TIMEOUT, nil - case channeldb.FailureReasonNoRoute: + case paymentsdb.FailureReasonNoRoute: return lnrpc.PaymentFailureReason_FAILURE_REASON_NO_ROUTE, nil - case channeldb.FailureReasonError: + case paymentsdb.FailureReasonError: return lnrpc.PaymentFailureReason_FAILURE_REASON_ERROR, nil - case channeldb.FailureReasonPaymentDetails: + case paymentsdb.FailureReasonPaymentDetails: return lnrpc.PaymentFailureReason_FAILURE_REASON_INCORRECT_PAYMENT_DETAILS, nil - case channeldb.FailureReasonInsufficientBalance: + case paymentsdb.FailureReasonInsufficientBalance: return lnrpc.PaymentFailureReason_FAILURE_REASON_INSUFFICIENT_BALANCE, nil - case channeldb.FailureReasonCanceled: + case paymentsdb.FailureReasonCanceled: return lnrpc.PaymentFailureReason_FAILURE_REASON_CANCELED, nil } diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index 843222eee4b..894edf28aa6 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -15,7 +15,6 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/lightningnetwork/lnd/aliasmgr" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" @@ -927,7 +926,7 @@ func (s *Server) SendToRouteV2(ctx context.Context, return nil, err } - var attempt *channeldb.HTLCAttempt + var attempt *paymentsdb.HTLCAttempt // Pass route to the router. This call returns the full htlc attempt // information as it is stored in the database. It is possible that both @@ -1449,17 +1448,21 @@ func (s *Server) trackPaymentStream(context context.Context, // No more payment updates. return nil } - result := item.(*channeldb.MPPayment) + result, ok := item.(*paymentsdb.MPPayment) + if !ok { + return fmt.Errorf("unexpected payment type: %T", + item) + } log.Tracef("Payment %v updated to state %v", result.Info.PaymentIdentifier, result.Status) // Skip in-flight updates unless requested. if noInflightUpdates { - if result.Status == channeldb.StatusInitiated { + if result.Status == paymentsdb.StatusInitiated { continue } - if result.Status == channeldb.StatusInFlight { + if result.Status == paymentsdb.StatusInFlight { continue } } diff --git a/lnrpc/routerrpc/router_server_test.go b/lnrpc/routerrpc/router_server_test.go index bc5a7f16d9a..70e3fbda89e 100644 --- a/lnrpc/routerrpc/router_server_test.go +++ b/lnrpc/routerrpc/router_server_test.go @@ -6,10 +6,10 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/channeldb" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwire" + paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/queue" "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing/route" @@ -129,13 +129,13 @@ func TestTrackPaymentsInflightUpdates(t *testing.T) { }() // Enqueue some payment updates on the mock. - towerMock.queue.ChanIn() <- &channeldb.MPPayment{ - Info: &channeldb.PaymentCreationInfo{}, - Status: channeldb.StatusInFlight, + towerMock.queue.ChanIn() <- &paymentsdb.MPPayment{ + Info: &paymentsdb.PaymentCreationInfo{}, + Status: paymentsdb.StatusInFlight, } - towerMock.queue.ChanIn() <- &channeldb.MPPayment{ - Info: &channeldb.PaymentCreationInfo{}, - Status: channeldb.StatusSucceeded, + towerMock.queue.ChanIn() <- &paymentsdb.MPPayment{ + Info: &paymentsdb.PaymentCreationInfo{}, + Status: paymentsdb.StatusSucceeded, } // Wait until there's 2 updates or the deadline is exceeded. @@ -191,13 +191,13 @@ func TestTrackPaymentsNoInflightUpdates(t *testing.T) { }() // Enqueue some payment updates on the mock. - towerMock.queue.ChanIn() <- &channeldb.MPPayment{ - Info: &channeldb.PaymentCreationInfo{}, - Status: channeldb.StatusInFlight, + towerMock.queue.ChanIn() <- &paymentsdb.MPPayment{ + Info: &paymentsdb.PaymentCreationInfo{}, + Status: paymentsdb.StatusInFlight, } - towerMock.queue.ChanIn() <- &channeldb.MPPayment{ - Info: &channeldb.PaymentCreationInfo{}, - Status: channeldb.StatusSucceeded, + towerMock.queue.ChanIn() <- &paymentsdb.MPPayment{ + Info: &paymentsdb.PaymentCreationInfo{}, + Status: paymentsdb.StatusSucceeded, } // Wait until there's 1 update or the deadline is exceeded. diff --git a/lnwallet/aux_signer.go b/lnwallet/aux_signer.go index 90a4325f60e..79a7ca1dc09 100644 --- a/lnwallet/aux_signer.go +++ b/lnwallet/aux_signer.go @@ -10,9 +10,20 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) -// htlcCustomSigType is the TLV type that is used to encode the custom HTLC -// signatures within the custom data for an existing HTLC. -var htlcCustomSigType tlv.TlvType65543 +var ( + // htlcCustomSigType is the TLV type that is used to encode the custom + // HTLC signatures within the custom data for an existing HTLC. + htlcCustomSigType tlv.TlvType65543 + + // NoOpHtlcTLVEntry is the TLV that that's used in the update_add_htlc + // message to indicate the presence of a noop HTLC. This has no encoded + // value, its presence is used to indicate that the HTLC is a noop. + NoOpHtlcTLVEntry tlv.TlvType65544 +) + +// NoOpHtlcTLVType is the (golang) type of the TLV record that's used to signal +// that an HTLC should be a noop HTLC. +type NoOpHtlcTLVType = tlv.TlvType65544 // AuxHtlcView is a struct that contains a safe copy of an HTLC view that can // be used by aux components. @@ -116,6 +127,18 @@ func (a *AuxHtlcDescriptor) AddHeight( return a.addCommitHeightLocal } +// IsAdd checks if the entry type of the Aux HTLC Descriptor is an add type. +func (a *AuxHtlcDescriptor) IsAdd() bool { + switch a.EntryType { + case Add: + fallthrough + case NoOpAdd: + return true + default: + return false + } +} + // RemoveHeight returns the height at which the HTLC was removed from the // commitment chain. The height is returned based on the chain the HTLC is being // removed from (local or remote chain). diff --git a/lnwallet/channel.go b/lnwallet/channel.go index ee45bf943d7..29de01618ee 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -551,6 +551,12 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, remoteOutputIndex = htlc.OutputIndex } + customRecords := htlc.CustomRecords.Copy() + + entryType := lc.entryTypeForHtlc( + customRecords, lc.channelState.ChanType, + ) + // With the scripts reconstructed (depending on if this is our commit // vs theirs or a pending commit for the remote party), we can now // re-create the original payment descriptor. @@ -559,7 +565,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, RHash: htlc.RHash, Timeout: htlc.RefundTimeout, Amount: htlc.Amt, - EntryType: Add, + EntryType: entryType, HtlcIndex: htlc.HtlcIndex, LogIndex: htlc.LogIndex, OnionBlob: htlc.OnionBlob, @@ -570,7 +576,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, theirPkScript: theirP2WSH, theirWitnessScript: theirWitnessScript, BlindingPoint: htlc.BlindingPoint, - CustomRecords: htlc.CustomRecords.Copy(), + CustomRecords: customRecords, }, nil } @@ -1100,6 +1106,10 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, }, } + pd.EntryType = lc.entryTypeForHtlc( + pd.CustomRecords, lc.channelState.ChanType, + ) + isDustRemote := HtlcIsDust( lc.channelState.ChanType, false, lntypes.Remote, feeRate, wireMsg.Amount.ToSatoshis(), remoteDustLimit, @@ -1336,6 +1346,10 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd }, } + pd.EntryType = lc.entryTypeForHtlc( + pd.CustomRecords, lc.channelState.ChanType, + ) + // We don't need to generate an htlc script yet. This will be // done once we sign our remote commitment. @@ -1736,7 +1750,7 @@ func (lc *LightningChannel) restorePendingRemoteUpdates( // but this Add restoration was a no-op as every single one of // these Adds was already restored since they're all incoming // htlcs on the local commitment. - if payDesc.EntryType == Add { + if payDesc.isAdd() { continue } @@ -1881,7 +1895,7 @@ func (lc *LightningChannel) restorePendingLocalUpdates( } switch payDesc.EntryType { - case Add: + case Add, NoOpAdd: // The HtlcIndex of the added HTLC _must_ be equal to // the log's htlcCounter at this point. If it is not we // panic to catch this. @@ -2993,6 +3007,22 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ) if rmvHeight == 0 { switch { + // If this a noop add, then when we settle the + // HTLC, we may credit the sender with the + // amount again, thus making it a noop. Noop + // HTLCs are only triggered by external software + // using the AuxComponents and only for channels + // that use the custom tapscript root. The + // criteria about whether the noop will be + // effective is whether the receiver is already + // sitting above reserve. + case entry.EntryType == Settle && + addEntry.EntryType == NoOpAdd: + + lc.evaluateNoOpHtlc( + entry, party, &balanceDeltas, + ) + // If an incoming HTLC is being settled, then // this means that the preimage has been // received by the settling party Therefore, we @@ -3030,7 +3060,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, liveAdds := fn.Filter( view.Updates.GetForParty(party), func(pd *paymentDescriptor) bool { - isAdd := pd.EntryType == Add + isAdd := pd.isAdd() shouldSkip := skip.GetForParty(party). Contains(pd.HtlcIndex) @@ -3069,7 +3099,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, // corresponding to whoseCommitmentChain. isUncommitted := func(update *paymentDescriptor) bool { switch update.EntryType { - case Add: + case Add, NoOpAdd: return update.addCommitHeights.GetForParty( whoseCommitChain, ) == 0 @@ -3145,6 +3175,92 @@ func (lc *LightningChannel) fetchParent(entry *paymentDescriptor, return addEntry, nil } +// balanceAboveReserve checks if the balance for the provided party is above the +// configured reserve. It also uses the balance delta for the party, to account +// for entry amounts that have been processed already. +func balanceAboveReserve(party lntypes.ChannelParty, delta int64, + channel *channeldb.OpenChannel) bool { + + // We're going to access the channel state, so let's make sure we're + // holding the lock. + channel.RLock() + defer channel.RUnlock() + + // For calculating whether a party is above reserve we are going to + // use the channel state local/remote balance of the corresponding + // commitment. This balance corresponds to the balance of each party + // after the most recent revocation. That's the balance on top of which + // we may apply the balance delta of the currently processed HTLCs. It + // is important for the calculated balance to match between us and our + // peer, as any disagreement over the balances here can lead to a force + // closure. + c := channel + + localReserve := lnwire.NewMSatFromSatoshis(c.LocalChanCfg.ChanReserve) + remoteReserve := lnwire.NewMSatFromSatoshis(c.RemoteChanCfg.ChanReserve) + + switch { + case party.IsLocal(): + // For the local party we'll consult the local balance of the + // local commitment. Then we'll correctly add the delta based on + // whether it's negative or not. + totalLocal := c.LocalCommitment.LocalBalance + if delta >= 0 { + totalLocal += lnwire.MilliSatoshi(delta) + } else { + totalLocal -= lnwire.MilliSatoshi(-1 * delta) + } + + return totalLocal > localReserve + + case party.IsRemote(): + // For the remote party we'll consult the remote balance of the + // remote commitment. Then we'll correctly add the delta based + // on whether it's negative or not. + totalRemote := c.RemoteCommitment.RemoteBalance + if delta >= 0 { + totalRemote += lnwire.MilliSatoshi(delta) + } else { + totalRemote -= lnwire.MilliSatoshi(-1 * delta) + } + + return totalRemote > remoteReserve + } + + return false +} + +// evaluateNoOpHtlc applies the balance delta based on whether the NoOp HTLC is +// considered effective. This depends on whether the receiver is already above +// the channel reserve. +func (lc *LightningChannel) evaluateNoOpHtlc(entry *paymentDescriptor, + party lntypes.ChannelParty, balanceDeltas *lntypes.Dual[int64]) { + + channel := lc.channelState + delta := balanceDeltas.GetForParty(party) + + // If the receiver has existing balance above reserve then we go ahead + // with crediting the amount back to the sender. Otherwise we give the + // amount to the receiver. We do this because the receiver needs some + // above reserve balance to anchor the AuxBlob. We also pass in the so + // far calculated delta for the party, as that's effectively part of + // their balance within this view computation. + if balanceAboveReserve(party, delta, channel) { + party = party.CounterParty() + + // The noop is effective, meaning that the settlement will + // credit the amount back to the sender. Let's mark this as it + // may be needed later when processing the settle entry, where + // we won't be able to perform the above check again. + entry.noOpSettle = true + } + + d := int64(entry.Amount) + balanceDeltas.ModifyForParty(party, func(acc int64) int64 { + return acc + d + }) +} + // generateRemoteHtlcSigJobs generates a series of HTLC signature jobs for the // sig pool, along with a channel that if closed, will cancel any jobs after // they have been submitted to the sigPool. This method is to be used when @@ -3833,7 +3949,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // Go through all updates, checking that they don't violate the // channel constraints. for _, entry := range updates { - if entry.EntryType == Add { + if entry.isAdd() { // An HTLC is being added, this will add to the // number and amount in flight. amtInFlight += entry.Amount @@ -4668,6 +4784,15 @@ func (lc *LightningChannel) computeView(view *HtlcView, if whoseCommitChain == lntypes.Local && u.EntryType == Settle { + // If this settle was a result of an + // effective noop add entry, then we + // don't need to record the amount as it + // was never sent over to the other + // side. + if u.noOpSettle { + continue + } + lc.recordSettlement(party, u.Amount) } } @@ -5712,7 +5837,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( // don't re-forward any already processed HTLC's after a // restart. switch { - case pd.EntryType == Add && committedAdd && shouldFwdAdd: + case pd.isAdd() && committedAdd && shouldFwdAdd: // Construct a reference specifying the location that // this forwarded Add will be written in the forwarding // package constructed at this remote height. @@ -5731,7 +5856,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( addUpdatesToForward, pd.toLogUpdate(), ) - case pd.EntryType != Add && committedRmv && shouldFwdRmv: + case !pd.isAdd() && committedRmv && shouldFwdRmv: // Construct a reference specifying the location that // this forwarded Settle/Fail will be written in the // forwarding package constructed at this remote height. @@ -5970,7 +6095,7 @@ func (lc *LightningChannel) GetDustSum(whoseCommit lntypes.ChannelParty, // Grab all of our HTLCs and evaluate against the dust limit. for e := lc.updateLogs.Local.Front(); e != nil; e = e.Next() { pd := e.Value - if pd.EntryType != Add { + if !pd.isAdd() { continue } @@ -5989,7 +6114,7 @@ func (lc *LightningChannel) GetDustSum(whoseCommit lntypes.ChannelParty, // Grab all of their HTLCs and evaluate against the dust limit. for e := lc.updateLogs.Remote.Front(); e != nil; e = e.Next() { pd := e.Value - if pd.EntryType != Add { + if !pd.isAdd() { continue } @@ -6062,9 +6187,14 @@ func (lc *LightningChannel) MayAddOutgoingHtlc(amt lnwire.MilliSatoshi) error { func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, openKey *models.CircuitKey) *paymentDescriptor { + customRecords := htlc.CustomRecords.Copy() + entryType := lc.entryTypeForHtlc( + customRecords, lc.channelState.ChanType, + ) + return &paymentDescriptor{ ChanID: htlc.ChanID, - EntryType: Add, + EntryType: entryType, RHash: PaymentHash(htlc.PaymentHash), Timeout: htlc.Expiry, Amount: htlc.Amount, @@ -6073,7 +6203,7 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, OnionBlob: htlc.OnionBlob, OpenCircuitKey: openKey, BlindingPoint: htlc.BlindingPoint, - CustomRecords: htlc.CustomRecords.Copy(), + CustomRecords: customRecords, } } @@ -6126,9 +6256,14 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, lc.updateLogs.Remote.htlcCounter) } + customRecords := htlc.CustomRecords.Copy() + entryType := lc.entryTypeForHtlc( + customRecords, lc.channelState.ChanType, + ) + pd := &paymentDescriptor{ ChanID: htlc.ChanID, - EntryType: Add, + EntryType: entryType, RHash: PaymentHash(htlc.PaymentHash), Timeout: htlc.Expiry, Amount: htlc.Amount, @@ -6136,7 +6271,7 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, HtlcIndex: lc.updateLogs.Remote.htlcCounter, OnionBlob: htlc.OnionBlob, BlindingPoint: htlc.BlindingPoint, - CustomRecords: htlc.CustomRecords.Copy(), + CustomRecords: customRecords, } localACKedIndex := lc.commitChains.Remote.tail().messageIndices.Local @@ -9825,7 +9960,7 @@ func (lc *LightningChannel) unsignedLocalUpdates(remoteMessageIndex, // We don't save add updates as they are restored from the // remote commitment in restoreStateLogs. - if pd.EntryType == Add { + if pd.isAdd() { continue } @@ -9999,3 +10134,23 @@ func (lc *LightningChannel) ZeroConfRealScid() fn.Option[lnwire.ShortChannelID] return fn.None[lnwire.ShortChannelID]() } + +// entryTypeForHtlc returns the add type that should be used for adding this +// HTLC to the channel. If the channel has a tapscript root and the HTLC carries +// the NoOp bit in the custom records then we'll convert this to a NoOp add. +func (lc *LightningChannel) entryTypeForHtlc(records lnwire.CustomRecords, + chanType channeldb.ChannelType) updateType { + + noopTLV := uint64(NoOpHtlcTLVEntry.TypeVal()) + _, noopFlag := records[noopTLV] + if noopFlag && chanType.HasTapscriptRoot() { + return NoOpAdd + } + + if noopFlag && !chanType.HasTapscriptRoot() { + lc.log.Warnf("Received flag for noop-add over a channel that " + + "doesn't have a tapscript root") + } + + return Add +} diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 0a0ca261c02..0316dcc54db 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -3232,7 +3232,9 @@ func restartChannel(channelOld *LightningChannel) (*LightningChannel, error) { // he receives Alice's CommitSig message, then Alice concludes that she needs // to re-send the CommitDiff. After the diff has been sent, both nodes should // resynchronize and be able to complete the dangling commit. -func testChanSyncOweCommitment(t *testing.T, chanType channeldb.ChannelType) { +func testChanSyncOweCommitment(t *testing.T, + chanType channeldb.ChannelType, noop bool) { + // Create a test channel which will be used for the duration of this // unittest. The channel will be funded evenly with Alice having 5 BTC, // and Bob having 5 BTC. @@ -3242,6 +3244,17 @@ func testChanSyncOweCommitment(t *testing.T, chanType channeldb.ChannelType) { var fakeOnionBlob [lnwire.OnionPacketSize]byte copy(fakeOnionBlob[:], bytes.Repeat([]byte{0x05}, lnwire.OnionPacketSize)) + // Let's create the noop add TLV record. This will only be + // effective for channels that have a tapscript root. + noopRecord := tlv.NewPrimitiveRecord[NoOpHtlcTLVType, bool](true) + records, err := tlv.RecordsToMap([]tlv.Record{noopRecord.Record()}) + require.NoError(t, err) + + // If the noop flag is not set for this test, nullify the records. + if !noop { + records = nil + } + // We'll start off the scenario with Bob sending 3 HTLC's to Alice in a // single state update. htlcAmt := lnwire.NewMSatFromSatoshis(20000) @@ -3251,10 +3264,11 @@ func testChanSyncOweCommitment(t *testing.T, chanType channeldb.ChannelType) { for i := 0; i < 3; i++ { rHash := sha256.Sum256(bobPreimage[:]) h := &lnwire.UpdateAddHTLC{ - PaymentHash: rHash, - Amount: htlcAmt, - Expiry: uint32(10), - OnionBlob: fakeOnionBlob, + PaymentHash: rHash, + Amount: htlcAmt, + Expiry: uint32(10), + OnionBlob: fakeOnionBlob, + CustomRecords: records, } htlcIndex, err := bobChannel.AddHTLC(h, nil) @@ -3290,15 +3304,17 @@ func testChanSyncOweCommitment(t *testing.T, chanType channeldb.ChannelType) { t.Fatalf("unable to settle htlc: %v", err) } } + var alicePreimage [32]byte copy(alicePreimage[:], bytes.Repeat([]byte{0xaa}, 32)) rHash := sha256.Sum256(alicePreimage[:]) aliceHtlc := &lnwire.UpdateAddHTLC{ - ChanID: chanID, - PaymentHash: rHash, - Amount: htlcAmt, - Expiry: uint32(10), - OnionBlob: fakeOnionBlob, + ChanID: chanID, + PaymentHash: rHash, + Amount: htlcAmt, + Expiry: uint32(10), + OnionBlob: fakeOnionBlob, + CustomRecords: records, } aliceHtlcIndex, err := aliceChannel.AddHTLC(aliceHtlc, nil) require.NoError(t, err, "unable to add alice's htlc") @@ -3519,22 +3535,25 @@ func testChanSyncOweCommitment(t *testing.T, chanType channeldb.ChannelType) { // At this point, the final balances of both parties should properly // reflect the amount of HTLC's sent. - bobMsatSent := numBobHtlcs * htlcAmt - if aliceChannel.channelState.TotalMSatSent != htlcAmt { - t.Fatalf("wrong value for msat sent: expected %v, got %v", - htlcAmt, aliceChannel.channelState.TotalMSatSent) - } - if aliceChannel.channelState.TotalMSatReceived != bobMsatSent { - t.Fatalf("wrong value for msat recv: expected %v, got %v", - bobMsatSent, aliceChannel.channelState.TotalMSatReceived) - } - if bobChannel.channelState.TotalMSatSent != bobMsatSent { - t.Fatalf("wrong value for msat sent: expected %v, got %v", - bobMsatSent, bobChannel.channelState.TotalMSatSent) - } - if bobChannel.channelState.TotalMSatReceived != htlcAmt { - t.Fatalf("wrong value for msat recv: expected %v, got %v", - htlcAmt, bobChannel.channelState.TotalMSatReceived) + if noop { + // If this test-case includes noop HTLCs, then we don't expect + // any balance changes. + require.Zero(t, aliceChannel.channelState.TotalMSatSent) + require.Zero(t, aliceChannel.channelState.TotalMSatReceived) + require.Zero(t, bobChannel.channelState.TotalMSatSent) + require.Zero(t, bobChannel.channelState.TotalMSatReceived) + } else { + // Otherwise, calculate the expected changes and assert them. + bobMsatSent := numBobHtlcs * htlcAmt + + aliceChan := aliceChannel.channelState + bobChan := bobChannel.channelState + + require.Equal(t, aliceChan.TotalMSatSent, htlcAmt) + require.Equal(t, aliceChan.TotalMSatReceived, bobMsatSent) + + require.Equal(t, bobChan.TotalMSatSent, bobMsatSent) + require.Equal(t, bobChan.TotalMSatReceived, htlcAmt) } } @@ -3548,6 +3567,7 @@ func TestChanSyncOweCommitment(t *testing.T) { testCases := []struct { name string chanType channeldb.ChannelType + noop bool }{ { name: "tweakless", @@ -3571,10 +3591,18 @@ func TestChanSyncOweCommitment(t *testing.T) { channeldb.SimpleTaprootFeatureBit | channeldb.TapscriptRootBit, }, + { + name: "tapscript root with noop", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit | + channeldb.TapscriptRootBit, + noop: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - testChanSyncOweCommitment(t, tc.chanType) + testChanSyncOweCommitment(t, tc.chanType, tc.noop) }) } } @@ -11339,3 +11367,393 @@ func TestCreateCooperativeCloseTx(t *testing.T) { }) } } + +// TestNoopAddSettle tests that adding and settling an HTLC with no-op, no +// balances are actually affected. +func TestNoopAddSettle(t *testing.T) { + t.Parallel() + + // Create a test channel which will be used for the duration of this + // unittest. The channel will be funded evenly with Alice having 5 BTC, + // and Bob having 5 BTC. + chanType := channeldb.SimpleTaprootFeatureBit | + channeldb.AnchorOutputsBit | channeldb.ZeroHtlcTxFeeBit | + channeldb.SingleFunderTweaklessBit | channeldb.TapscriptRootBit + aliceChannel, bobChannel, err := CreateTestChannels( + t, chanType, + ) + require.NoError(t, err, "unable to create test channels") + + const htlcAmt = 10_000 + htlc, preimage := createHTLC(0, htlcAmt) + noopRecord := tlv.NewPrimitiveRecord[tlv.TlvType65544, bool](true) + + records, err := tlv.RecordsToMap([]tlv.Record{noopRecord.Record()}) + require.NoError(t, err) + htlc.CustomRecords = records + + aliceBalance := aliceChannel.channelState.LocalCommitment.LocalBalance + bobBalance := bobChannel.channelState.LocalCommitment.LocalBalance + + // Have Alice add the HTLC, then lock it in with a new state transition. + aliceHtlcIndex, err := aliceChannel.AddHTLC(htlc, nil) + require.NoError(t, err, "alice unable to add htlc") + bobHtlcIndex, err := bobChannel.ReceiveHTLC(htlc) + require.NoError(t, err, "bob unable to receive htlc") + + err = ForceStateTransition(aliceChannel, bobChannel) + require.NoError(t, err) + + // We'll have Bob settle the HTLC, then force another state transition. + err = bobChannel.SettleHTLC(preimage, bobHtlcIndex, nil, nil, nil) + require.NoError(t, err, "bob unable to settle inbound htlc") + err = aliceChannel.ReceiveHTLCSettle(preimage, aliceHtlcIndex) + require.NoError(t, err) + + err = ForceStateTransition(aliceChannel, bobChannel) + require.NoError(t, err) + + aliceBalanceFinal := aliceChannel.channelState.LocalCommitment.LocalBalance //nolint:ll + bobBalanceFinal := bobChannel.channelState.LocalCommitment.LocalBalance + + // The balances of Alice and Bob should be the exact same and shouldn't + // have changed. + require.Equal(t, aliceBalance, aliceBalanceFinal) + require.Equal(t, bobBalance, bobBalanceFinal) +} + +// TestNoopAddBelowReserve tests that the noop HTLCs behave as expected when +// added over a channel where a party is below their reserve. +func TestNoopAddBelowReserve(t *testing.T) { + t.Parallel() + + // Create a test channel which will be used for the duration of this + // unittest. The channel will be funded evenly with Alice having 5 BTC, + // and Bob having 5 BTC. + chanType := channeldb.SimpleTaprootFeatureBit | + channeldb.AnchorOutputsBit | channeldb.ZeroHtlcTxFeeBit | + channeldb.SingleFunderTweaklessBit | channeldb.TapscriptRootBit + aliceChan, bobChan, err := CreateTestChannels(t, chanType) + require.NoError(t, err, "unable to create test channels") + + aliceBalance := aliceChan.channelState.LocalCommitment.LocalBalance + bobBalance := bobChan.channelState.LocalCommitment.LocalBalance + + const ( + // htlcAmt is the default HTLC amount to be used, epxressed in + // milli-satoshis. + htlcAmt = lnwire.MilliSatoshi(500_000) + + // numHtlc is the total number of HTLCs to be added/settled over + // the channel. + numHtlc = 20 + ) + + // Let's create the noop add TLV record to be used in all added HTLCs + // over the channel. + noopRecord := tlv.NewPrimitiveRecord[NoOpHtlcTLVType, bool](true) + records, err := tlv.RecordsToMap([]tlv.Record{noopRecord.Record()}) + require.NoError(t, err) + + // Let's set Bob's reserve to whatever his local balance is, plus half + // of the total amount to be added by the total HTLCs. This way we can + // also verify that the noop-adds will start the nullification only once + // Bob is above reserve. + reserveTarget := (numHtlc / 2) * htlcAmt + bobReserve := bobBalance + reserveTarget + + bobChan.channelState.LocalChanCfg.ChanReserve = + bobReserve.ToSatoshis() + + aliceChan.channelState.RemoteChanCfg.ChanReserve = + bobReserve.ToSatoshis() + + // Add and settle all the HTLCs over the channel. + for i := range numHtlc { + htlc, preimage := createHTLC(i, htlcAmt) + htlc.CustomRecords = records + + aliceHtlcIndex, err := aliceChan.AddHTLC(htlc, nil) + require.NoError(t, err, "alice unable to add htlc") + bobHtlcIndex, err := bobChan.ReceiveHTLC(htlc) + require.NoError(t, err, "bob unable to receive htlc") + + require.NoError(t, ForceStateTransition(aliceChan, bobChan)) + + // We'll have Bob settle the HTLC, then force another state + // transition. + err = bobChan.SettleHTLC(preimage, bobHtlcIndex, nil, nil, nil) + require.NoError(t, err, "bob unable to settle inbound htlc") + err = aliceChan.ReceiveHTLCSettle(preimage, aliceHtlcIndex) + require.NoError(t, err) + require.NoError(t, ForceStateTransition(aliceChan, bobChan)) + } + + // We need to kick the state transition one last time for the balances + // to be updated on both commitments. + require.NoError(t, ForceStateTransition(aliceChan, bobChan)) + + aliceBalanceFinal := aliceChan.channelState.LocalCommitment.LocalBalance + bobBalanceFinal := bobChan.channelState.LocalCommitment.LocalBalance + + // The balances of Alice and Bob must have changed exactly by half the + // total number of HTLCs we added over the channel, plus one to get Bob + // above the reserve. Bob's final balance should be as much as his + // reserve plus one extra default HTLC amount. + require.Equal(t, aliceBalance-htlcAmt*(numHtlc/2+1), aliceBalanceFinal) + require.Equal(t, bobBalance+htlcAmt*(numHtlc/2+1), bobBalanceFinal) + require.Equal( + t, bobBalanceFinal.ToSatoshis(), + bobChan.LocalChanReserve()+htlcAmt.ToSatoshis(), + ) +} + +// TestEvaluateNoOpHtlc tests that the noop htlc evaluator helper function +// produces the expected balance deltas from various starting states. +func TestEvaluateNoOpHtlc(t *testing.T) { + testCases := []struct { + name string + localBalance, remoteBalance btcutil.Amount + localReserve, remoteReserve btcutil.Amount + entry *paymentDescriptor + receiver lntypes.ChannelParty + balanceDeltas *lntypes.Dual[int64] + expectedDeltas *lntypes.Dual[int64] + }{ + { + name: "local above reserve", + entry: &paymentDescriptor{ + Amount: lnwire.MilliSatoshi(2500), + }, + receiver: lntypes.Local, + balanceDeltas: &lntypes.Dual[int64]{ + Local: 0, + Remote: 0, + }, + expectedDeltas: &lntypes.Dual[int64]{ + Local: 0, + Remote: 2_500, + }, + }, + { + name: "remote above reserve", + entry: &paymentDescriptor{ + Amount: lnwire.MilliSatoshi(2500), + }, + receiver: lntypes.Remote, + balanceDeltas: &lntypes.Dual[int64]{ + Local: 0, + Remote: 0, + }, + expectedDeltas: &lntypes.Dual[int64]{ + Local: 2_500, + Remote: 0, + }, + }, + { + name: "local below reserve", + entry: &paymentDescriptor{ + Amount: lnwire.MilliSatoshi(2500), + }, + receiver: lntypes.Local, + localBalance: 25_000, + localReserve: 50_000, + balanceDeltas: &lntypes.Dual[int64]{ + Local: 0, + Remote: 0, + }, + expectedDeltas: &lntypes.Dual[int64]{ + Local: 2_500, + Remote: 0, + }, + }, + { + name: "remote below reserve", + entry: &paymentDescriptor{ + Amount: lnwire.MilliSatoshi(2500), + }, + receiver: lntypes.Remote, + remoteBalance: 25_000, + remoteReserve: 50_000, + balanceDeltas: &lntypes.Dual[int64]{ + Local: 0, + Remote: 0, + }, + expectedDeltas: &lntypes.Dual[int64]{ + Local: 0, + Remote: 2_500, + }, + }, + + { + name: "local above reserve with delta", + entry: &paymentDescriptor{ + Amount: lnwire.MilliSatoshi(2500), + }, + receiver: lntypes.Local, + localBalance: 25_000, + localReserve: 50_000, + balanceDeltas: &lntypes.Dual[int64]{ + Local: 25_001_000, + Remote: 0, + }, + expectedDeltas: &lntypes.Dual[int64]{ + Local: 25_001_000, + Remote: 2_500, + }, + }, + { + name: "remote above reserve with delta", + entry: &paymentDescriptor{ + Amount: lnwire.MilliSatoshi(2500), + }, + receiver: lntypes.Remote, + remoteBalance: 25_000, + remoteReserve: 50_000, + balanceDeltas: &lntypes.Dual[int64]{ + Local: 0, + Remote: 25_001_000, + }, + expectedDeltas: &lntypes.Dual[int64]{ + Local: 2_500, + Remote: 25_001_000, + }, + }, + { + name: "local below reserve with delta", + entry: &paymentDescriptor{ + Amount: lnwire.MilliSatoshi(2500), + }, + receiver: lntypes.Local, + localBalance: 25_000, + localReserve: 50_000, + balanceDeltas: &lntypes.Dual[int64]{ + Local: 24_999_000, + Remote: 0, + }, + expectedDeltas: &lntypes.Dual[int64]{ + Local: 25_001_500, + Remote: 0, + }, + }, + { + name: "remote below reserve with delta", + entry: &paymentDescriptor{ + Amount: lnwire.MilliSatoshi(2500), + }, + receiver: lntypes.Remote, + remoteBalance: 25_000, + remoteReserve: 50_000, + balanceDeltas: &lntypes.Dual[int64]{ + Local: 0, + Remote: 24_998_000, + }, + expectedDeltas: &lntypes.Dual[int64]{ + Local: 0, + Remote: 25_000_500, + }, + }, + { + name: "local above reserve with negative delta", + entry: &paymentDescriptor{ + Amount: lnwire.MilliSatoshi(2500), + }, + receiver: lntypes.Remote, + localBalance: 55_000, + localReserve: 50_000, + balanceDeltas: &lntypes.Dual[int64]{ + Local: -4_999_000, + Remote: 0, + }, + expectedDeltas: &lntypes.Dual[int64]{ + Local: -4_999_000, + Remote: 2_500, + }, + }, + { + name: "remote above reserve with negative delta", + entry: &paymentDescriptor{ + Amount: lnwire.MilliSatoshi(2500), + }, + receiver: lntypes.Remote, + remoteBalance: 55_000, + remoteReserve: 50_000, + balanceDeltas: &lntypes.Dual[int64]{ + Local: 0, + Remote: -4_999_000, + }, + expectedDeltas: &lntypes.Dual[int64]{ + Local: 2_500, + Remote: -4_999_000, + }, + }, + { + name: "local below reserve with negative delta", + entry: &paymentDescriptor{ + Amount: lnwire.MilliSatoshi(2500), + }, + receiver: lntypes.Local, + localBalance: 55_000, + localReserve: 50_000, + balanceDeltas: &lntypes.Dual[int64]{ + Local: -5_001_000, + Remote: 0, + }, + expectedDeltas: &lntypes.Dual[int64]{ + Local: -4_998_500, + Remote: 0, + }, + }, + { + name: "remote below reserve with negative delta", + entry: &paymentDescriptor{ + Amount: lnwire.MilliSatoshi(2500), + }, + receiver: lntypes.Remote, + remoteBalance: 55_000, + remoteReserve: 50_000, + balanceDeltas: &lntypes.Dual[int64]{ + Local: 0, + Remote: -5_001_000, + }, + expectedDeltas: &lntypes.Dual[int64]{ + Local: 0, + Remote: -4_998_500, + }, + }, + } + + chanType := channeldb.SimpleTaprootFeatureBit | + channeldb.AnchorOutputsBit | channeldb.ZeroHtlcTxFeeBit | + channeldb.SingleFunderTweaklessBit | channeldb.TapscriptRootBit + aliceChan, _, err := CreateTestChannels(t, chanType) + require.NoError(t, err, "unable to create test channels") + + for _, testCase := range testCases { + tc := testCase + + t.Logf("Running test case: %s", testCase.name) + + if tc.localBalance != 0 && tc.localReserve != 0 { + aliceChan.channelState.LocalChanCfg.ChanReserve = + tc.localReserve + + aliceChan.channelState.LocalCommitment.LocalBalance = + lnwire.NewMSatFromSatoshis(tc.localBalance) + } + + if tc.remoteBalance != 0 && tc.remoteReserve != 0 { + aliceChan.channelState.RemoteChanCfg.ChanReserve = + tc.remoteReserve + + aliceChan.channelState.RemoteCommitment.RemoteBalance = + lnwire.NewMSatFromSatoshis(tc.remoteBalance) + } + + aliceChan.evaluateNoOpHtlc( + tc.entry, tc.receiver, tc.balanceDeltas, + ) + + require.Equal(t, tc.expectedDeltas, tc.balanceDeltas) + } +} diff --git a/lnwallet/payment_descriptor.go b/lnwallet/payment_descriptor.go index 49b79a139dc..944749bde9f 100644 --- a/lnwallet/payment_descriptor.go +++ b/lnwallet/payment_descriptor.go @@ -42,6 +42,13 @@ const ( // FeeUpdate is an update type sent by the channel initiator that // updates the fee rate used when signing the commitment transaction. FeeUpdate + + // NoOpAdd is an update type that adds a new HTLC entry into the log. + // This differs from the normal Add type, in that when settled the + // balance may go back to the sender, rather than be credited for the + // receiver. The criteria about whether the balance will go back to the + // sender is whether the receiver is sitting above the channel reserve. + NoOpAdd ) // String returns a human readable string that uniquely identifies the target @@ -58,6 +65,8 @@ func (u updateType) String() string { return "Settle" case FeeUpdate: return "FeeUpdate" + case NoOpAdd: + return "NoOpAdd" default: return "" } @@ -216,6 +225,14 @@ type paymentDescriptor struct { // into the log to the HTLC being modified. EntryType updateType + // noOpSettle is a flag indicating whether a chain of entries resulted + // in an effective no-op settle. That means that the amount was credited + // back to the sender. This is useful as we need a way to mark whether + // the noop add was effective, which can be useful at later stages, + // where we might not be able to re-run the criteria for the + // effectiveness of the noop-add. + noOpSettle bool + // isForwarded denotes if an incoming HTLC has been forwarded to any // possible upstream peers in the route. isForwarded bool @@ -238,7 +255,7 @@ type paymentDescriptor struct { func (pd *paymentDescriptor) toLogUpdate() channeldb.LogUpdate { var msg lnwire.Message switch pd.EntryType { - case Add: + case Add, NoOpAdd: msg = &lnwire.UpdateAddHTLC{ ChanID: pd.ChanID, ID: pd.HtlcIndex, @@ -290,7 +307,7 @@ func (pd *paymentDescriptor) setCommitHeight( whoseCommitChain lntypes.ChannelParty, nextHeight uint64) { switch pd.EntryType { - case Add: + case Add, NoOpAdd: pd.addCommitHeights.SetForParty( whoseCommitChain, nextHeight, ) @@ -311,3 +328,8 @@ func (pd *paymentDescriptor) setCommitHeight( ) } } + +// isAdd returns true if the paymentDescriptor is of type Add. +func (pd *paymentDescriptor) isAdd() bool { + return pd.EntryType == Add || pd.EntryType == NoOpAdd +} diff --git a/make/builder.Dockerfile b/make/builder.Dockerfile index 66e9bc8c422..d2abf4b9669 100644 --- a/make/builder.Dockerfile +++ b/make/builder.Dockerfile @@ -1,6 +1,6 @@ # If you change this please also update GO_VERSION in Makefile (then run # `make lint` to see where else it needs to be updated as well). -FROM golang:1.23.10-bookworm +FROM golang:1.23.12-bookworm MAINTAINER Olaoluwa Osuntokun diff --git a/payments/db/codec.go b/payments/db/codec.go new file mode 100644 index 00000000000..50b357a7f4e --- /dev/null +++ b/payments/db/codec.go @@ -0,0 +1,141 @@ +package paymentsdb + +import ( + "encoding/binary" + "errors" + "io" + "time" + + "github.com/lightningnetwork/lnd/channeldb" +) + +// Big endian is the preferred byte order, due to cursor scans over +// integer keys iterating in order. +var byteOrder = binary.BigEndian + +// UnknownElementType is an alias for channeldb.UnknownElementType. +type UnknownElementType = channeldb.UnknownElementType + +// ReadElement deserializes a single element from the provided io.Reader. +func ReadElement(r io.Reader, element interface{}) error { + err := channeldb.ReadElement(r, element) + switch { + // Known to channeldb codec. + case err == nil: + return nil + + // Fail if error is not UnknownElementType. + default: + var unknownElementType UnknownElementType + if !errors.As(err, &unknownElementType) { + return err + } + } + + // Process any wtdb-specific extensions to the codec. + switch e := element.(type) { + case *paymentIndexType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + // Type is still unknown to wtdb extensions, fail. + default: + return channeldb.NewUnknownElementType( + "ReadElement", element, + ) + } + + return nil +} + +// WriteElement serializes a single element into the provided io.Writer. +func WriteElement(w io.Writer, element interface{}) error { + err := channeldb.WriteElement(w, element) + switch { + // Known to channeldb codec. + case err == nil: + return nil + + // Fail if error is not UnknownElementType. + default: + var unknownElementType UnknownElementType + if !errors.As(err, &unknownElementType) { + return err + } + } + + // Process any wtdb-specific extensions to the codec. + switch e := element.(type) { + case paymentIndexType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + // Type is still unknown to wtdb extensions, fail. + default: + return channeldb.NewUnknownElementType( + "WriteElement", element, + ) + } + + return nil +} + +// WriteElements serializes a variadic list of elements into the given +// io.Writer. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + if err := WriteElement(w, element); err != nil { + return err + } + } + + return nil +} + +// ReadElements deserializes the provided io.Reader into a variadic list of +// target elements. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + if err := ReadElement(r, element); err != nil { + return err + } + } + + return nil +} + +// deserializeTime deserializes time as unix nanoseconds. +func deserializeTime(r io.Reader) (time.Time, error) { + var scratch [8]byte + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return time.Time{}, err + } + + // Convert to time.Time. Interpret unix nano time zero as a zero + // time.Time value. + unixNano := byteOrder.Uint64(scratch[:]) + if unixNano == 0 { + return time.Time{}, nil + } + + return time.Unix(0, int64(unixNano)), nil +} + +// serializeTime serializes time as unix nanoseconds. +func serializeTime(w io.Writer, t time.Time) error { + var scratch [8]byte + + // Convert to unix nano seconds, but only if time is non-zero. Calling + // UnixNano() on a zero time yields an undefined result. + var unixNano int64 + if !t.IsZero() { + unixNano = t.UnixNano() + } + + byteOrder.PutUint64(scratch[:], uint64(unixNano)) + _, err := w.Write(scratch[:]) + + return err +} diff --git a/payments/db/interface.go b/payments/db/interface.go new file mode 100644 index 00000000000..03d80423760 --- /dev/null +++ b/payments/db/interface.go @@ -0,0 +1,75 @@ +package paymentsdb + +import ( + "context" + + "github.com/lightningnetwork/lnd/lntypes" +) + +// PaymentDB is the interface that represents the underlying payments database. +type PaymentDB interface { + PaymentReader + PaymentWriter +} + +// PaymentReader is the interface that reads from the payments database. +type PaymentReader interface { + // QueryPayments queries the payments database and should support + // pagination. + QueryPayments(ctx context.Context, query Query) (Response, error) + + // FetchPayment fetches the payment corresponding to the given payment + // hash. + FetchPayment(paymentHash lntypes.Hash) (*MPPayment, error) + + // FetchInFlightPayments returns all payments with status InFlight. + FetchInFlightPayments() ([]*MPPayment, error) +} + +// PaymentWriter is the interface that writes to the payments database. +type PaymentWriter interface { + // DeletePayment deletes a payment from the DB given its payment hash. + DeletePayment(paymentHash lntypes.Hash, failedAttemptsOnly bool) error + + // DeletePayments deletes all payments from the DB given the specified + // flags. + DeletePayments(failedOnly, failedAttemptsOnly bool) (int, error) + + // DeleteFailedAttempts removes all failed HTLCs from the db. It should + // be called for a given payment whenever all inflight htlcs are + // completed, and the payment has reached a final settled state. + DeleteFailedAttempts(lntypes.Hash) error + + PaymentControl +} + +// PaymentControl is the interface that controls the payment lifecycle. +type PaymentControl interface { + // This method checks that no succeeded payment exist for this payment + // hash. + InitPayment(lntypes.Hash, *PaymentCreationInfo) error + + // RegisterAttempt atomically records the provided HTLCAttemptInfo. + RegisterAttempt(lntypes.Hash, *HTLCAttemptInfo) (*MPPayment, error) + + // SettleAttempt marks the given attempt settled with the preimage. If + // this is a multi shard payment, this might implicitly mean the + // full payment succeeded. + // + // After invoking this method, InitPayment should always return an + // error to prevent us from making duplicate payments to the same + // payment hash. The provided preimage is atomically saved to the DB + // for record keeping. + SettleAttempt(lntypes.Hash, uint64, *HTLCSettleInfo) (*MPPayment, error) + + // FailAttempt marks the given payment attempt failed. + FailAttempt(lntypes.Hash, uint64, *HTLCFailInfo) (*MPPayment, error) + + // Fail transitions a payment into the Failed state, and records + // the ultimate reason the payment failed. Note that this should only + // be called when all active attempts are already failed. After + // invoking this method, InitPayment should return nil on its next call + // for this payment hash, allowing the user to make a subsequent + // payment. + Fail(lntypes.Hash, FailureReason) (*MPPayment, error) +} diff --git a/channeldb/duplicate_payments.go b/payments/db/kv_duplicate_payments.go similarity index 98% rename from channeldb/duplicate_payments.go rename to payments/db/kv_duplicate_payments.go index 004722f0067..3ac1faddd60 100644 --- a/channeldb/duplicate_payments.go +++ b/payments/db/kv_duplicate_payments.go @@ -1,4 +1,4 @@ -package channeldb +package paymentsdb import ( "bytes" @@ -11,7 +11,6 @@ import ( "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" - paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/routing/route" ) @@ -75,7 +74,7 @@ func fetchDuplicatePaymentStatus(bucket kvdb.RBucket) (PaymentStatus, error) { return StatusInFlight, nil } - return 0, paymentsdb.ErrPaymentNotInitiated + return 0, ErrPaymentNotInitiated } func deserializeDuplicateHTLCAttemptInfo(r io.Reader) ( diff --git a/channeldb/payments_kv_store.go b/payments/db/kv_store.go similarity index 97% rename from channeldb/payments_kv_store.go rename to payments/db/kv_store.go index cf2ca918b98..afc7dec0ac0 100644 --- a/channeldb/payments_kv_store.go +++ b/payments/db/kv_store.go @@ -1,4 +1,4 @@ -package channeldb +package paymentsdb import ( "bytes" @@ -13,10 +13,10 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" - paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" @@ -127,19 +127,21 @@ type KVPaymentsDB struct { // db is the underlying database implementation. db kvdb.Backend + // keepFailedPaymentAttempts is a flag that indicates whether we should + // keep failed payment attempts in the database. keepFailedPaymentAttempts bool } // defaultKVStoreOptions returns the default options for the KV store. -func defaultKVStoreOptions() *paymentsdb.StoreOptions { - return &paymentsdb.StoreOptions{ +func defaultKVStoreOptions() *StoreOptions { + return &StoreOptions{ KeepFailedPaymentAttempts: false, } } // NewKVPaymentsDB creates a new KVStore for payments. func NewKVPaymentsDB(db kvdb.Backend, - options ...paymentsdb.OptionModifier) (*KVPaymentsDB, error) { + options ...OptionModifier) (*KVPaymentsDB, error) { opts := defaultKVStoreOptions() for _, applyOption := range options { @@ -158,6 +160,8 @@ func NewKVPaymentsDB(db kvdb.Backend, }, nil } +// paymentsTopLevelBuckets is a list of top-level buckets that are used for +// the payments database when using the kv store. var paymentsTopLevelBuckets = [][]byte{ paymentsRootBucket, paymentsIndexBucket, @@ -229,7 +233,7 @@ func (p *KVPaymentsDB) InitPayment(paymentHash lntypes.Hash, // Otherwise, if the error is not `ErrPaymentNotInitiated`, // we'll return the error. - case !errors.Is(err, paymentsdb.ErrPaymentNotInitiated): + case !errors.Is(err, ErrPaymentNotInitiated): return err } @@ -403,7 +407,7 @@ func (p *KVPaymentsDB) RegisterAttempt(paymentHash lntypes.Hash, // MPP records should not be set for attempts to blinded paths. if isBlinded && mpp != nil { - return paymentsdb.ErrMPPRecordInBlindedPayment + return ErrMPPRecordInBlindedPayment } for _, h := range payment.InFlightHTLCs() { @@ -412,7 +416,7 @@ func (p *KVPaymentsDB) RegisterAttempt(paymentHash lntypes.Hash, // If this is a blinded payment, then no existing HTLCs // should have MPP records. if isBlinded && hMpp != nil { - return paymentsdb.ErrMPPRecordInBlindedPayment + return ErrMPPRecordInBlindedPayment } // If this is a blinded payment, then we just need to @@ -424,7 +428,7 @@ func (p *KVPaymentsDB) RegisterAttempt(paymentHash lntypes.Hash, h.Route.FinalHop().TotalAmtMsat { //nolint:ll - return paymentsdb.ErrBlindedPaymentTotalAmountMismatch + return ErrBlindedPaymentTotalAmountMismatch } continue @@ -434,12 +438,12 @@ func (p *KVPaymentsDB) RegisterAttempt(paymentHash lntypes.Hash, // We tried to register a non-MPP attempt for a MPP // payment. case mpp == nil && hMpp != nil: - return paymentsdb.ErrMPPayment + return ErrMPPayment // We tried to register a MPP shard for a non-MPP // payment. case mpp != nil && hMpp == nil: - return paymentsdb.ErrNonMPPayment + return ErrNonMPPayment // Non-MPP payment, nothing more to validate. case mpp == nil: @@ -448,11 +452,11 @@ func (p *KVPaymentsDB) RegisterAttempt(paymentHash lntypes.Hash, // Check that MPP options match. if mpp.PaymentAddr() != hMpp.PaymentAddr() { - return paymentsdb.ErrMPPPaymentAddrMismatch + return ErrMPPPaymentAddrMismatch } if mpp.TotalMsat() != hMpp.TotalMsat() { - return paymentsdb.ErrMPPTotalAmountMismatch + return ErrMPPTotalAmountMismatch } } @@ -461,14 +465,14 @@ func (p *KVPaymentsDB) RegisterAttempt(paymentHash lntypes.Hash, // attempt. amt := attempt.Route.ReceiverAmt() if !isBlinded && mpp == nil && amt != payment.Info.Value { - return paymentsdb.ErrValueMismatch + return ErrValueMismatch } // Ensure we aren't sending more than the total payment amount. sentAmt, _ := payment.SentAmt() if sentAmt+amt > payment.Info.Value { return fmt.Errorf("%w: attempted=%v, payment amount="+ - "%v", paymentsdb.ErrValueExceedsAmt, + "%v", ErrValueExceedsAmt, sentAmt+amt, payment.Info.Value) } @@ -574,12 +578,12 @@ func (p *KVPaymentsDB) updateHtlcKey(paymentHash lntypes.Hash, // Make sure the shard is not already failed or settled. failKey := htlcBucketKey(htlcFailInfoKey, aid) if htlcsBucket.Get(failKey) != nil { - return paymentsdb.ErrAttemptAlreadyFailed + return ErrAttemptAlreadyFailed } settleKey := htlcBucketKey(htlcSettleInfoKey, aid) if htlcsBucket.Get(settleKey) != nil { - return paymentsdb.ErrAttemptAlreadySettled + return ErrAttemptAlreadySettled } // Add or update the key for this htlc. @@ -619,8 +623,8 @@ func (p *KVPaymentsDB) Fail(paymentHash lntypes.Hash, prefetchPayment(tx, paymentHash) bucket, err := fetchPaymentBucketUpdate(tx, paymentHash) - if errors.Is(err, paymentsdb.ErrPaymentNotInitiated) { - updateErr = paymentsdb.ErrPaymentNotInitiated + if errors.Is(err, ErrPaymentNotInitiated) { + updateErr = ErrPaymentNotInitiated return nil } else if err != nil { return err @@ -631,8 +635,8 @@ func (p *KVPaymentsDB) Fail(paymentHash lntypes.Hash, // failure to the KVPaymentsDB without synchronizing with // other attempts. _, err = fetchPaymentStatus(bucket) - if errors.Is(err, paymentsdb.ErrPaymentNotInitiated) { - updateErr = paymentsdb.ErrPaymentNotInitiated + if errors.Is(err, ErrPaymentNotInitiated) { + updateErr = ErrPaymentNotInitiated return nil } else if err != nil { return err @@ -725,12 +729,12 @@ func fetchPaymentBucket(tx kvdb.RTx, paymentHash lntypes.Hash) ( payments := tx.ReadBucket(paymentsRootBucket) if payments == nil { - return nil, paymentsdb.ErrPaymentNotInitiated + return nil, ErrPaymentNotInitiated } bucket := payments.NestedReadBucket(paymentHash[:]) if bucket == nil { - return nil, paymentsdb.ErrPaymentNotInitiated + return nil, ErrPaymentNotInitiated } return bucket, nil @@ -743,12 +747,12 @@ func fetchPaymentBucketUpdate(tx kvdb.RwTx, paymentHash lntypes.Hash) ( payments := tx.ReadWriteBucket(paymentsRootBucket) if payments == nil { - return nil, paymentsdb.ErrPaymentNotInitiated + return nil, ErrPaymentNotInitiated } bucket := payments.NestedReadWriteBucket(paymentHash[:]) if bucket == nil { - return nil, paymentsdb.ErrPaymentNotInitiated + return nil, ErrPaymentNotInitiated } return bucket, nil @@ -805,7 +809,7 @@ func fetchPaymentStatus(bucket kvdb.RBucket) (PaymentStatus, error) { // Creation info should be set for all payments, regardless of state. // If not, it is unknown. if bucket.Get(paymentCreationInfoKey) == nil { - return 0, paymentsdb.ErrPaymentNotInitiated + return 0, ErrPaymentNotInitiated } payment, err := fetchPayment(bucket) @@ -1051,7 +1055,7 @@ func fetchHtlcAttempts(bucket kvdb.RBucket) ([]HTLCAttempt, error) { // Sanity check that all htlcs have an attempt info. if attemptInfoCount != len(htlcsMap) { - return nil, paymentsdb.ErrNoAttemptInfo + return nil, ErrNoAttemptInfo } keys := make([]uint64, len(htlcsMap)) @@ -1131,9 +1135,9 @@ func fetchFailedHtlcKeys(bucket kvdb.RBucket) ([][]byte, error) { // to a subset of payments by the payments query, containing an offset // index and a maximum number of returned payments. func (p *KVPaymentsDB) QueryPayments(_ context.Context, - query PaymentsQuery) (PaymentsResponse, error) { + query Query) (Response, error) { - var resp PaymentsResponse + var resp Response if err := kvdb.View(p.db, func(tx kvdb.RTx) error { // Get the root payments bucket. @@ -1205,13 +1209,13 @@ func (p *KVPaymentsDB) QueryPayments(_ context.Context, // Create a paginator which reads from our sequence index bucket // with the parameters provided by the payments query. - paginator := newPaginator( + paginator := channeldb.NewPaginator( indexes.ReadCursor(), query.Reversed, query.IndexOffset, query.MaxPayments, ) // Run a paginated query, adding payments to our response. - if err := paginator.query(accumulatePayments); err != nil { + if err := paginator.Query(accumulatePayments); err != nil { return err } @@ -1247,7 +1251,7 @@ func (p *KVPaymentsDB) QueryPayments(_ context.Context, return nil }, func() { - resp = PaymentsResponse{} + resp = Response{} }); err != nil { return resp, err } @@ -1290,7 +1294,7 @@ func fetchPaymentWithSequenceNumber(tx kvdb.RTx, paymentHash lntypes.Hash, // the payment we are actually looking for. seqBytes := bucket.Get(paymentSequenceKey) if seqBytes == nil { - return nil, paymentsdb.ErrNoSequenceNumber + return nil, ErrNoSequenceNumber } // If this top level payment has the sequence number we are looking for, @@ -1305,7 +1309,7 @@ func fetchPaymentWithSequenceNumber(tx kvdb.RTx, paymentHash lntypes.Hash, // find a duplicate payments bucket here, something is wrong. dup := bucket.NestedReadBucket(duplicatePaymentsBucket) if dup == nil { - return nil, paymentsdb.ErrNoDuplicateBucket + return nil, ErrNoDuplicateBucket } var duplicatePayment *MPPayment @@ -1313,7 +1317,7 @@ func fetchPaymentWithSequenceNumber(tx kvdb.RTx, paymentHash lntypes.Hash, subBucket := dup.NestedReadBucket(k) if subBucket == nil { // We one bucket for each duplicate to be found. - return paymentsdb.ErrNoDuplicateNestedBucket + return ErrNoDuplicateNestedBucket } seqBytes := subBucket.Get(duplicatePaymentSequenceKey) @@ -1342,7 +1346,7 @@ func fetchPaymentWithSequenceNumber(tx kvdb.RTx, paymentHash lntypes.Hash, // failed to find the payment with this sequence number; something is // wrong. if duplicatePayment == nil { - return nil, paymentsdb.ErrDuplicateNotFound + return nil, ErrDuplicateNotFound } return duplicatePayment, nil diff --git a/payments/db/kv_store_test.go b/payments/db/kv_store_test.go new file mode 100644 index 00000000000..f02dbccd6d3 --- /dev/null +++ b/payments/db/kv_store_test.go @@ -0,0 +1,1065 @@ +package paymentsdb + +import ( + "bytes" + "context" + "math" + "reflect" + "testing" + "time" + + "github.com/btcsuite/btcwallet/walletdb" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestKVPaymentsDBDeleteNonInFlight checks that calling DeletePayments only +// deletes payments from the database that are not in-flight. +func TestKVPaymentsDBDeleteNonInFlight(t *testing.T) { + t.Parallel() + + paymentDB := NewKVTestDB(t) + + // Create a sequence number for duplicate payments that will not collide + // with the sequence numbers for the payments we create. These values + // start at 1, so 9999 is a safe bet for this test. + var duplicateSeqNr = 9999 + + payments := []struct { + failed bool + success bool + hasDuplicate bool + }{ + { + failed: true, + success: false, + hasDuplicate: false, + }, + { + failed: false, + success: true, + hasDuplicate: false, + }, + { + failed: false, + success: false, + hasDuplicate: false, + }, + { + failed: false, + success: true, + hasDuplicate: true, + }, + } + + var numSuccess, numInflight int + + for _, p := range payments { + info, attempt, preimg, err := genInfo(t) + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate StatusInFlight. + err = paymentDB.InitPayment(info.PaymentIdentifier, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + _, err = paymentDB.RegisterAttempt( + info.PaymentIdentifier, attempt, + ) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + htlc := &htlcStatus{ + HTLCAttemptInfo: attempt, + } + + switch { + case p.failed: + // Fail the payment attempt. + htlcFailure := HTLCFailUnreadable + _, err := paymentDB.FailAttempt( + info.PaymentIdentifier, attempt.AttemptID, + &HTLCFailInfo{ + Reason: htlcFailure, + }, + ) + if err != nil { + t.Fatalf("unable to fail htlc: %v", err) + } + + // Fail the payment, which should moved it to Failed. + failReason := FailureReasonNoRoute + _, err = paymentDB.Fail( + info.PaymentIdentifier, failReason, + ) + if err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } + + // Verify the status is indeed Failed. + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, + StatusFailed, + ) + + htlc.failure = &htlcFailure + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, + &failReason, htlc, + ) + + case p.success: + // Verifies that status was changed to StatusSucceeded. + _, err := paymentDB.SettleAttempt( + info.PaymentIdentifier, attempt.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + if err != nil { + t.Fatalf("error shouldn't have been received,"+ + " got: %v", err) + } + + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, + StatusSucceeded, + ) + + htlc.settle = &preimg + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, + htlc, + ) + + numSuccess++ + + default: + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, + StatusInFlight, + ) + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, + htlc, + ) + + numInflight++ + } + + // If the payment is intended to have a duplicate payment, we + // add one. + if p.hasDuplicate { + appendDuplicatePayment( + t, paymentDB.db, info.PaymentIdentifier, + uint64(duplicateSeqNr), preimg, + ) + duplicateSeqNr++ + numSuccess++ + } + } + + // Delete all failed payments. + numPayments, err := paymentDB.DeletePayments(true, false) + require.NoError(t, err) + require.EqualValues(t, 1, numPayments) + + // This should leave the succeeded and in-flight payments. + dbPayments, err := paymentDB.FetchPayments() + if err != nil { + t.Fatal(err) + } + + if len(dbPayments) != numSuccess+numInflight { + t.Fatalf("expected %d payments, got %d", + numSuccess+numInflight, len(dbPayments)) + } + + var s, i int + for _, p := range dbPayments { + t.Log("fetch payment has status", p.Status) + switch p.Status { + case StatusSucceeded: + s++ + case StatusInFlight: + i++ + } + } + + if s != numSuccess { + t.Fatalf("expected %d succeeded payments , got %d", + numSuccess, s) + } + if i != numInflight { + t.Fatalf("expected %d in-flight payments, got %d", + numInflight, i) + } + + // Now delete all payments except in-flight. + numPayments, err = paymentDB.DeletePayments(false, false) + require.NoError(t, err) + require.EqualValues(t, 2, numPayments) + + // This should leave the in-flight payment. + dbPayments, err = paymentDB.FetchPayments() + if err != nil { + t.Fatal(err) + } + + if len(dbPayments) != numInflight { + t.Fatalf("expected %d payments, got %d", numInflight, + len(dbPayments)) + } + + for _, p := range dbPayments { + if p.Status != StatusInFlight { + t.Fatalf("expected in-fligth status, got %v", p.Status) + } + } + + // Finally, check that we only have a single index left in the payment + // index bucket. + var indexCount int + err = kvdb.View(paymentDB.db, func(tx walletdb.ReadTx) error { + index := tx.ReadBucket(paymentsIndexBucket) + + return index.ForEach(func(k, v []byte) error { + indexCount++ + return nil + }) + }, func() { indexCount = 0 }) + require.NoError(t, err) + + require.Equal(t, 1, indexCount) +} + +type htlcStatus struct { + *HTLCAttemptInfo + settle *lntypes.Preimage + failure *HTLCFailReason +} + +// fetchPaymentIndexEntry gets the payment hash for the sequence number provided +// from our payment indexes bucket. +func fetchPaymentIndexEntry(_ *testing.T, p *KVPaymentsDB, + sequenceNumber uint64) (*lntypes.Hash, error) { + + var hash lntypes.Hash + + if err := kvdb.View(p.db, func(tx walletdb.ReadTx) error { + indexBucket := tx.ReadBucket(paymentsIndexBucket) + key := make([]byte, 8) + byteOrder.PutUint64(key, sequenceNumber) + + indexValue := indexBucket.Get(key) + if indexValue == nil { + return ErrNoSequenceNrIndex + } + + r := bytes.NewReader(indexValue) + + var err error + hash, err = deserializePaymentIndex(r) + + return err + }, func() { + hash = lntypes.Hash{} + }); err != nil { + return nil, err + } + + return &hash, nil +} + +// assertPaymentIndex looks up the index for a payment in the db and checks +// that its payment hash matches the expected hash passed in. +func assertPaymentIndex(t *testing.T, p PaymentDB, expectedHash lntypes.Hash) { + t.Helper() + + // Only the kv implementation uses the index so we exit early if the + // payment db is not a kv implementation. This helps us to reuse the + // same test for both implementations. + kvPaymentDB, ok := p.(*KVPaymentsDB) + if !ok { + return + } + + // Lookup the payment so that we have its sequence number and check + // that is has correctly been indexed in the payment indexes bucket. + pmt, err := kvPaymentDB.FetchPayment(expectedHash) + require.NoError(t, err) + + hash, err := fetchPaymentIndexEntry(t, kvPaymentDB, pmt.SequenceNum) + require.NoError(t, err) + assert.Equal(t, expectedHash, *hash) +} + +// assertNoIndex checks that an index for the sequence number provided does not +// exist. +func assertNoIndex(t *testing.T, p PaymentDB, seqNr uint64) { + t.Helper() + + kvPaymentDB, ok := p.(*KVPaymentsDB) + if !ok { + return + } + + _, err := fetchPaymentIndexEntry(t, kvPaymentDB, seqNr) + require.Equal(t, ErrNoSequenceNrIndex, err) +} + +func makeFakeInfo(t *testing.T) (*PaymentCreationInfo, + *HTLCAttemptInfo) { + + var preimg lntypes.Preimage + copy(preimg[:], rev[:]) + + hash := preimg.Hash() + + c := &PaymentCreationInfo{ + PaymentIdentifier: hash, + Value: 1000, + // Use single second precision to avoid false positive test + // failures due to the monotonic time component. + CreationTime: time.Unix(time.Now().Unix(), 0), + PaymentRequest: []byte("test"), + } + + a, err := NewHtlcAttempt( + 44, priv, testRoute, time.Unix(100, 0), &hash, + ) + require.NoError(t, err) + + return c, &a.HTLCAttemptInfo +} + +func TestSentPaymentSerialization(t *testing.T) { + t.Parallel() + + c, s := makeFakeInfo(t) + + var b bytes.Buffer + require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize") + + // Assert the length of the serialized creation info is as expected, + // without any custom records. + baseLength := 32 + 8 + 8 + 4 + len(c.PaymentRequest) + require.Len(t, b.Bytes(), baseLength) + + newCreationInfo, err := deserializePaymentCreationInfo(&b) + require.NoError(t, err, "deserialize") + require.Equal(t, c, newCreationInfo) + + b.Reset() + + // Now we add some custom records to the creation info and serialize it + // again. + c.FirstHopCustomRecords = lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3}, + } + require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize") + + newCreationInfo, err = deserializePaymentCreationInfo(&b) + require.NoError(t, err, "deserialize") + require.Equal(t, c, newCreationInfo) + + b.Reset() + require.NoError(t, serializeHTLCAttemptInfo(&b, s), "serialize") + + newWireInfo, err := deserializeHTLCAttemptInfo(&b) + require.NoError(t, err, "deserialize") + + // First we verify all the records match up properly. + require.Equal(t, s.Route, newWireInfo.Route) + + // We now add the new fields and custom records to the route and + // serialize it again. + b.Reset() + s.Route.FirstHopAmount = tlv.NewRecordT[tlv.TlvType0]( + tlv.NewBigSizeT(lnwire.MilliSatoshi(1234)), + ) + s.Route.FirstHopWireCustomRecords = lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType + 3: []byte{4, 5, 6}, + } + require.NoError(t, serializeHTLCAttemptInfo(&b, s), "serialize") + + newWireInfo, err = deserializeHTLCAttemptInfo(&b) + require.NoError(t, err, "deserialize") + require.Equal(t, s.Route, newWireInfo.Route) + + err = newWireInfo.attachOnionBlobAndCircuit() + require.NoError(t, err) + + // Clear routes to allow DeepEqual to compare the remaining fields. + newWireInfo.Route = route.Route{} + s.Route = route.Route{} + newWireInfo.AttemptID = s.AttemptID + + // Call session key method to set our cached session key so we can use + // DeepEqual, and assert that our key equals the original key. + require.Equal(t, s.cachedSessionKey, newWireInfo.SessionKey()) + + require.Equal(t, s, newWireInfo) +} + +// TestRouteSerialization tests serialization of a regular and blinded route. +func TestRouteSerialization(t *testing.T) { + t.Parallel() + + testSerializeRoute(t, testRoute) + testSerializeRoute(t, testBlindedRoute) +} + +func testSerializeRoute(t *testing.T, route route.Route) { + var b bytes.Buffer + err := SerializeRoute(&b, route) + require.NoError(t, err) + + r := bytes.NewReader(b.Bytes()) + route2, err := DeserializeRoute(r) + require.NoError(t, err) + + reflect.DeepEqual(route, route2) +} + +// deletePayment removes a payment with paymentHash from the payments database. +func deletePayment(t *testing.T, db kvdb.Backend, paymentHash lntypes.Hash, + seqNr uint64) { + + t.Helper() + + err := kvdb.Update(db, func(tx kvdb.RwTx) error { + payments := tx.ReadWriteBucket(paymentsRootBucket) + + // Delete the payment bucket. + err := payments.DeleteNestedBucket(paymentHash[:]) + if err != nil { + return err + } + + key := make([]byte, 8) + byteOrder.PutUint64(key, seqNr) + + // Delete the index that references this payment. + indexes := tx.ReadWriteBucket(paymentsIndexBucket) + + return indexes.Delete(key) + }, func() {}) + + if err != nil { + t.Fatalf("could not delete "+ + "payment: %v", err) + } +} + +// TestFetchPaymentWithSequenceNumber tests lookup of payments with their +// sequence number. It sets up one payment with no duplicates, and another with +// two duplicates in its duplicates bucket then uses these payments to test the +// case where a specific duplicate is not found and the duplicates bucket is not +// present when we expect it to be. +func TestFetchPaymentWithSequenceNumber(t *testing.T) { + paymentDB := NewKVTestDB(t) + + // Generate a test payment which does not have duplicates. + noDuplicates, _, _, err := genInfo(t) + require.NoError(t, err) + + // Create a new payment entry in the database. + err = paymentDB.InitPayment( + noDuplicates.PaymentIdentifier, noDuplicates, + ) + require.NoError(t, err) + + // Fetch the payment so we can get its sequence nr. + noDuplicatesPayment, err := paymentDB.FetchPayment( + noDuplicates.PaymentIdentifier, + ) + require.NoError(t, err) + + // Generate a test payment which we will add duplicates to. + hasDuplicates, _, preimg, err := genInfo(t) + require.NoError(t, err) + + // Create a new payment entry in the database. + err = paymentDB.InitPayment( + hasDuplicates.PaymentIdentifier, hasDuplicates, + ) + require.NoError(t, err) + + // Fetch the payment so we can get its sequence nr. + hasDuplicatesPayment, err := paymentDB.FetchPayment( + hasDuplicates.PaymentIdentifier, + ) + require.NoError(t, err) + + // We declare the sequence numbers used here so that we can reference + // them in tests. + var ( + duplicateOneSeqNr = hasDuplicatesPayment.SequenceNum + 1 + duplicateTwoSeqNr = hasDuplicatesPayment.SequenceNum + 2 + ) + + // Add two duplicates to our second payment. + appendDuplicatePayment( + t, paymentDB.db, hasDuplicates.PaymentIdentifier, + duplicateOneSeqNr, preimg, + ) + appendDuplicatePayment( + t, paymentDB.db, hasDuplicates.PaymentIdentifier, + duplicateTwoSeqNr, preimg, + ) + + tests := []struct { + name string + paymentHash lntypes.Hash + sequenceNumber uint64 + expectedErr error + }{ + { + name: "lookup payment without duplicates", + paymentHash: noDuplicates.PaymentIdentifier, + sequenceNumber: noDuplicatesPayment.SequenceNum, + expectedErr: nil, + }, + { + name: "lookup payment with duplicates", + paymentHash: hasDuplicates.PaymentIdentifier, + sequenceNumber: hasDuplicatesPayment.SequenceNum, + expectedErr: nil, + }, + { + name: "lookup first duplicate", + paymentHash: hasDuplicates.PaymentIdentifier, + sequenceNumber: duplicateOneSeqNr, + expectedErr: nil, + }, + { + name: "lookup second duplicate", + paymentHash: hasDuplicates.PaymentIdentifier, + sequenceNumber: duplicateTwoSeqNr, + expectedErr: nil, + }, + { + name: "lookup non-existent duplicate", + paymentHash: hasDuplicates.PaymentIdentifier, + sequenceNumber: 999999, + expectedErr: ErrDuplicateNotFound, + }, + { + name: "lookup duplicate, no duplicates " + + "bucket", + paymentHash: noDuplicates.PaymentIdentifier, + sequenceNumber: duplicateTwoSeqNr, + expectedErr: ErrNoDuplicateBucket, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + //nolint:ll + err := kvdb.Update( + paymentDB.db, func(tx walletdb.ReadWriteTx) error { + var seqNrBytes [8]byte + byteOrder.PutUint64( + seqNrBytes[:], + test.sequenceNumber, + ) + + //nolint:ll + _, err := fetchPaymentWithSequenceNumber( + tx, test.paymentHash, seqNrBytes[:], + ) + + return err + }, func() {}, + ) + require.Equal(t, test.expectedErr, err) + }) + } +} + +// appendDuplicatePayment adds a duplicate payment to an existing payment. Note +// that this function requires a unique sequence number. +// +// This code is *only* intended to replicate legacy duplicate payments in lnd, +// our current schema does not allow duplicates. +func appendDuplicatePayment(t *testing.T, db kvdb.Backend, + paymentHash lntypes.Hash, seqNr uint64, preImg lntypes.Preimage) { + + err := kvdb.Update(db, func(tx walletdb.ReadWriteTx) error { + bucket, err := fetchPaymentBucketUpdate( + tx, paymentHash, + ) + if err != nil { + return err + } + + // Create the duplicates bucket if it is not + // present. + dup, err := bucket.CreateBucketIfNotExists( + duplicatePaymentsBucket, + ) + if err != nil { + return err + } + + var sequenceKey [8]byte + byteOrder.PutUint64(sequenceKey[:], seqNr) + + // Create duplicate payments for the two dup + // sequence numbers we've setup. + putDuplicatePayment(t, dup, sequenceKey[:], paymentHash, preImg) + + // Finally, once we have created our entry we add an index for + // it. + err = createPaymentIndexEntry(tx, sequenceKey[:], paymentHash) + require.NoError(t, err) + + return nil + }, func() {}) + require.NoError(t, err, "could not create payment") +} + +// putDuplicatePayment creates a duplicate payment in the duplicates bucket +// provided with the minimal information required for successful reading. +func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket, + sequenceKey []byte, paymentHash lntypes.Hash, + preImg lntypes.Preimage) { + + paymentBucket, err := duplicateBucket.CreateBucketIfNotExists( + sequenceKey, + ) + require.NoError(t, err) + + err = paymentBucket.Put(duplicatePaymentSequenceKey, sequenceKey) + require.NoError(t, err) + + // Generate fake information for the duplicate payment. + info, _, _, err := genInfo(t) + require.NoError(t, err) + + // Write the payment info to disk under the creation info key. This code + // is copied rather than using serializePaymentCreationInfo to ensure + // we always write in the legacy format used by duplicate payments. + var b bytes.Buffer + var scratch [8]byte + _, err = b.Write(paymentHash[:]) + require.NoError(t, err) + + byteOrder.PutUint64(scratch[:], uint64(info.Value)) + _, err = b.Write(scratch[:]) + require.NoError(t, err) + + err = serializeTime(&b, info.CreationTime) + require.NoError(t, err) + + byteOrder.PutUint32(scratch[:4], 0) + _, err = b.Write(scratch[:4]) + require.NoError(t, err) + + // Get the PaymentCreationInfo. + err = paymentBucket.Put(duplicatePaymentCreationInfoKey, b.Bytes()) + require.NoError(t, err) + + // Duolicate payments are only stored for successes, so add the + // preimage. + err = paymentBucket.Put(duplicatePaymentSettleInfoKey, preImg[:]) + require.NoError(t, err) +} + +// TestQueryPayments tests retrieval of payments with forwards and reversed +// queries. +func TestQueryPayments(t *testing.T) { + // Define table driven test for QueryPayments. + // Test payments have sequence indices [1, 3, 4, 5, 6, 7]. + // Note that the payment with index 7 has the same payment hash as 6, + // and is stored in a nested bucket within payment 6 rather than being + // its own entry in the payments bucket. We do this to test retrieval + // of legacy payments. + tests := []struct { + name string + query Query + firstIndex uint64 + lastIndex uint64 + + // expectedSeqNrs contains the set of sequence numbers we expect + // our query to return. + expectedSeqNrs []uint64 + }{ + { + name: "IndexOffset at the end of the payments range", + query: Query{ + IndexOffset: 7, + MaxPayments: 7, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 0, + lastIndex: 0, + expectedSeqNrs: nil, + }, + { + name: "query in forwards order, start at beginning", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "query in forwards order, start at end, overflow", + query: Query{ + IndexOffset: 6, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 7, + lastIndex: 7, + expectedSeqNrs: []uint64{7}, + }, + { + name: "start at offset index outside of payments", + query: Query{ + IndexOffset: 20, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 0, + lastIndex: 0, + expectedSeqNrs: nil, + }, + { + name: "overflow in forwards order", + query: Query{ + IndexOffset: 4, + MaxPayments: math.MaxUint64, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 5, + lastIndex: 7, + expectedSeqNrs: []uint64{5, 6, 7}, + }, + { + name: "start at offset index outside of payments, " + + "reversed order", + query: Query{ + IndexOffset: 9, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 6, + lastIndex: 7, + expectedSeqNrs: []uint64{6, 7}, + }, + { + name: "query in reverse order, start at end", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 6, + lastIndex: 7, + expectedSeqNrs: []uint64{6, 7}, + }, + { + name: "query in reverse order, starting in middle", + query: Query{ + IndexOffset: 4, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "query in reverse order, starting in middle, " + + "with underflow", + query: Query{ + IndexOffset: 4, + MaxPayments: 5, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "all payments in reverse, order maintained", + query: Query{ + IndexOffset: 0, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 7, + expectedSeqNrs: []uint64{1, 3, 4, 5, 6, 7}, + }, + { + name: "exclude incomplete payments", + query: Query{ + IndexOffset: 0, + MaxPayments: 7, + Reversed: false, + IncludeIncomplete: false, + }, + firstIndex: 7, + lastIndex: 7, + expectedSeqNrs: []uint64{7}, + }, + { + name: "query payments at index gap", + query: Query{ + IndexOffset: 1, + MaxPayments: 7, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 3, + lastIndex: 7, + expectedSeqNrs: []uint64{3, 4, 5, 6, 7}, + }, + { + name: "query payments reverse before index gap", + query: Query{ + IndexOffset: 3, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments reverse on index gap", + query: Query{ + IndexOffset: 2, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments forward on index gap", + query: Query{ + IndexOffset: 2, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 3, + lastIndex: 4, + expectedSeqNrs: []uint64{3, 4}, + }, + { + name: "query in forwards order, with start creation " + + "time", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CreationDateStart: 5, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "query in forwards order, with start creation " + + "time at end, overflow", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CreationDateStart: 7, + }, + firstIndex: 7, + lastIndex: 7, + expectedSeqNrs: []uint64{7}, + }, + { + name: "query with start and end creation time", + query: Query{ + IndexOffset: 9, + MaxPayments: math.MaxUint64, + Reversed: true, + IncludeIncomplete: true, + CreationDateStart: 3, + CreationDateEnd: 5, + }, + firstIndex: 3, + lastIndex: 5, + expectedSeqNrs: []uint64{3, 4, 5}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + paymentDB := NewKVTestDB(t) + + // Initialize the payment database. + paymentDB, err := NewKVPaymentsDB(paymentDB.db) + require.NoError(t, err) + + // Make a preliminary query to make sure it's ok to + // query when we have no payments. + resp, err := paymentDB.QueryPayments(ctx, tt.query) + require.NoError(t, err) + require.Len(t, resp.Payments, 0) + + // Populate the database with a set of test payments. + // We create 6 original payments, deleting the payment + // at index 2 so that we cover the case where sequence + // numbers are missing. We also add a duplicate payment + // to the last payment added to test the legacy case + // where we have duplicates in the nested duplicates + // bucket. + nonDuplicatePayments := 6 + + for i := 0; i < nonDuplicatePayments; i++ { + // Generate a test payment. + info, _, preimg, err := genInfo(t) + if err != nil { + t.Fatalf("unable to create test "+ + "payment: %v", err) + } + // Override creation time to allow for testing + // of CreationDateStart and CreationDateEnd. + info.CreationTime = time.Unix(int64(i+1), 0) + + // Create a new payment entry in the database. + err = paymentDB.InitPayment( + info.PaymentIdentifier, info, + ) + require.NoError(t, err) + + // Immediately delete the payment with index 2. + if i == 1 { + pmt, err := paymentDB.FetchPayment( + info.PaymentIdentifier, + ) + require.NoError(t, err) + + deletePayment( + t, paymentDB.db, + info.PaymentIdentifier, + pmt.SequenceNum, + ) + } + + // If we are on the last payment entry, add a + // duplicate payment with sequence number equal + // to the parent payment + 1. Note that + // duplicate payments will always be succeeded. + if i == (nonDuplicatePayments - 1) { + pmt, err := paymentDB.FetchPayment( + info.PaymentIdentifier, + ) + require.NoError(t, err) + + appendDuplicatePayment( + t, paymentDB.db, + info.PaymentIdentifier, + pmt.SequenceNum+1, + preimg, + ) + } + } + + // Fetch all payments in the database. + allPayments, err := paymentDB.FetchPayments() + if err != nil { + t.Fatalf("payments could not be fetched from "+ + "database: %v", err) + } + + if len(allPayments) != 6 { + t.Fatalf("Number of payments received does "+ + "not match expected one. Got %v, "+ + "want %v.", len(allPayments), 6) + } + + querySlice, err := paymentDB.QueryPayments( + ctx, tt.query, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.firstIndex != querySlice.FirstIndexOffset || + tt.lastIndex != querySlice.LastIndexOffset { + + t.Errorf("First or last index does not match "+ + "expected index. Want (%d, %d), "+ + "got (%d, %d).", + tt.firstIndex, tt.lastIndex, + querySlice.FirstIndexOffset, + querySlice.LastIndexOffset) + } + + if len(querySlice.Payments) != len(tt.expectedSeqNrs) { + t.Errorf("expected: %v payments, got: %v", + len(tt.expectedSeqNrs), + len(querySlice.Payments)) + } + + for i, seqNr := range tt.expectedSeqNrs { + q := querySlice.Payments[i] + if seqNr != q.SequenceNum { + t.Errorf("sequence numbers do not "+ + "match, got %v, want %v", + q.SequenceNum, seqNr) + } + } + }) + } +} + +// TestLazySessionKeyDeserialize tests that we can read htlc attempt session +// keys that were previously serialized as a private key as raw bytes. +func TestLazySessionKeyDeserialize(t *testing.T) { + var b bytes.Buffer + + // Serialize as a private key. + err := WriteElements(&b, priv) + require.NoError(t, err) + + // Deserialize into [btcec.PrivKeyBytesLen]byte. + attempt := HTLCAttemptInfo{} + err = ReadElements(&b, &attempt.sessionKey) + require.NoError(t, err) + require.Zero(t, b.Len()) + + sessionKey := attempt.SessionKey() + require.Equal(t, priv, sessionKey) +} diff --git a/payments/db/log.go b/payments/db/log.go index 08994efab89..8a77dbcec7f 100644 --- a/payments/db/log.go +++ b/payments/db/log.go @@ -8,8 +8,6 @@ import ( // log is a logger that is initialized with no output filters. This // means the package will not perform any logging by default until the caller // requests it. -// -//nolint:unused var log btclog.Logger // Subsystem defines the logging identifier for this subsystem. diff --git a/channeldb/mp_payment.go b/payments/db/payment.go similarity index 86% rename from channeldb/mp_payment.go rename to payments/db/payment.go index f75357afc0b..72338a4e50e 100644 --- a/channeldb/mp_payment.go +++ b/payments/db/payment.go @@ -1,4 +1,4 @@ -package channeldb +package paymentsdb import ( "bytes" @@ -15,10 +15,94 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" - paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/routing/route" ) +// FailureReason encodes the reason a payment ultimately failed. +type FailureReason byte + +const ( + // FailureReasonTimeout indicates that the payment did timeout before a + // successful payment attempt was made. + FailureReasonTimeout FailureReason = 0 + + // FailureReasonNoRoute indicates no successful route to the + // destination was found during path finding. + FailureReasonNoRoute FailureReason = 1 + + // FailureReasonError indicates that an unexpected error happened during + // payment. + FailureReasonError FailureReason = 2 + + // FailureReasonPaymentDetails indicates that either the hash is unknown + // or the final cltv delta or amount is incorrect. + FailureReasonPaymentDetails FailureReason = 3 + + // FailureReasonInsufficientBalance indicates that we didn't have enough + // balance to complete the payment. + FailureReasonInsufficientBalance FailureReason = 4 + + // FailureReasonCanceled indicates that the payment was canceled by the + // user. + FailureReasonCanceled FailureReason = 5 + + // TODO(joostjager): Add failure reasons for: + // LocalLiquidityInsufficient, RemoteCapacityInsufficient. +) + +// Error returns a human-readable error string for the FailureReason. +func (r FailureReason) Error() string { + return r.String() +} + +// String returns a human-readable FailureReason. +func (r FailureReason) String() string { + switch r { + case FailureReasonTimeout: + return "timeout" + case FailureReasonNoRoute: + return "no_route" + case FailureReasonError: + return "error" + case FailureReasonPaymentDetails: + return "incorrect_payment_details" + case FailureReasonInsufficientBalance: + return "insufficient_balance" + case FailureReasonCanceled: + return "canceled" + } + + return "unknown" +} + +// PaymentCreationInfo is the information necessary to have ready when +// initiating a payment, moving it into state InFlight. +type PaymentCreationInfo struct { + // PaymentIdentifier is the hash this payment is paying to in case of + // non-AMP payments, and the SetID for AMP payments. + PaymentIdentifier lntypes.Hash + + // Value is the amount we are paying. + Value lnwire.MilliSatoshi + + // CreationTime is the time when this payment was initiated. + CreationTime time.Time + + // PaymentRequest is the full payment request, if any. + PaymentRequest []byte + + // FirstHopCustomRecords are the TLV records that are to be sent to the + // first hop of this payment. These records will be transmitted via the + // wire message only and therefore do not affect the onion payload size. + FirstHopCustomRecords lnwire.CustomRecords +} + +// String returns a human-readable description of the payment creation info. +func (p *PaymentCreationInfo) String() string { + return fmt.Sprintf("payment_id=%v, amount=%v, created_at=%v", + p.PaymentIdentifier, p.Value, p.CreationTime) +} + // HTLCAttemptInfo contains static information about a specific HTLC attempt // for a payment. This information is used by the router to handle any errors // coming back after an attempt is made, and to query the switch about the @@ -171,8 +255,8 @@ const ( // reason. HTLCFailUnknown HTLCFailReason = 0 - // HTLCFailUnknown is recorded for htlcs that had a failure message that - // couldn't be decrypted. + // HTLCFailUnreadable is recorded for htlcs that had a failure message + // that couldn't be decrypted. HTLCFailUnreadable HTLCFailReason = 1 // HTLCFailInternal is recorded for htlcs that failed because of an @@ -350,12 +434,12 @@ func (m *MPPayment) Registrable() error { // are settled HTLCs or the payment is failed. If we already have // settled HTLCs, we won't allow adding more HTLCs. if m.State.HasSettledHTLC { - return paymentsdb.ErrPaymentPendingSettled + return ErrPaymentPendingSettled } // If the payment is already failed, we won't allow adding more HTLCs. if m.State.PaymentFailed { - return paymentsdb.ErrPaymentPendingFailed + return ErrPaymentPendingFailed } // Otherwise we can add more HTLCs. @@ -373,7 +457,7 @@ func (m *MPPayment) setState() error { totalAmt := m.Info.Value if sentAmt > totalAmt { return fmt.Errorf("%w: sent=%v, total=%v", - paymentsdb.ErrSentExceedsTotal, sentAmt, totalAmt) + ErrSentExceedsTotal, sentAmt, totalAmt) } // Get any terminal info for this payment. @@ -452,7 +536,7 @@ func (m *MPPayment) NeedWaitAttempts() (bool, error) { case StatusSucceeded: return false, fmt.Errorf("%w: parts of the payment "+ "already succeeded but still have remaining "+ - "amount %v", paymentsdb.ErrPaymentInternal, + "amount %v", ErrPaymentInternal, m.State.RemainingAmt) // The payment is failed and we have no inflight HTLCs, no need @@ -463,7 +547,7 @@ func (m *MPPayment) NeedWaitAttempts() (bool, error) { // Unknown payment status. default: return false, fmt.Errorf("%w: %s", - paymentsdb.ErrUnknownPaymentStatus, m.Status) + ErrUnknownPaymentStatus, m.Status) } } @@ -474,7 +558,7 @@ func (m *MPPayment) NeedWaitAttempts() (bool, error) { // amount, return an error. case StatusInitiated: return false, fmt.Errorf("%w: %v", - paymentsdb.ErrPaymentInternal, m.Status) + ErrPaymentInternal, m.Status) // If the payment is inflight, we must wait. // @@ -496,12 +580,12 @@ func (m *MPPayment) NeedWaitAttempts() (bool, error) { // not be zero because our sentAmt is zero. case StatusFailed: return false, fmt.Errorf("%w: %v", - paymentsdb.ErrPaymentInternal, m.Status) + ErrPaymentInternal, m.Status) // Unknown payment status. default: return false, fmt.Errorf("%w: %s", - paymentsdb.ErrUnknownPaymentStatus, m.Status) + ErrUnknownPaymentStatus, m.Status) } } @@ -510,12 +594,12 @@ func (m *MPPayment) GetState() *MPPaymentState { return m.State } -// Status returns the current status of the payment. +// GetStatus returns the current status of the payment. func (m *MPPayment) GetStatus() PaymentStatus { return m.Status } -// GetPayment returns all the HTLCs for this payment. +// GetHTLCs returns all the HTLCs for this payment. func (m *MPPayment) GetHTLCs() []HTLCAttempt { return m.HTLCs } @@ -532,7 +616,7 @@ func (m *MPPayment) AllowMoreAttempts() (bool, error) { if m.Status == StatusInitiated { return false, fmt.Errorf("%w: initiated payment has "+ "zero remainingAmt", - paymentsdb.ErrPaymentInternal) + ErrPaymentInternal) } // Otherwise, exit early since all other statuses with zero @@ -549,7 +633,7 @@ func (m *MPPayment) AllowMoreAttempts() (bool, error) { if m.Status == StatusSucceeded { return false, fmt.Errorf("%w: payment already succeeded but "+ "still have remaining amount %v", - paymentsdb.ErrPaymentInternal, m.State.RemainingAmt) + ErrPaymentInternal, m.State.RemainingAmt) } // Now check if we can register a new HTLC. @@ -654,39 +738,6 @@ func deserializeHTLCFailInfo(r io.Reader) (*HTLCFailInfo, error) { return f, nil } -// deserializeTime deserializes time as unix nanoseconds. -func deserializeTime(r io.Reader) (time.Time, error) { - var scratch [8]byte - if _, err := io.ReadFull(r, scratch[:]); err != nil { - return time.Time{}, err - } - - // Convert to time.Time. Interpret unix nano time zero as a zero - // time.Time value. - unixNano := byteOrder.Uint64(scratch[:]) - if unixNano == 0 { - return time.Time{}, nil - } - - return time.Unix(0, int64(unixNano)), nil -} - -// serializeTime serializes time as unix nanoseconds. -func serializeTime(w io.Writer, t time.Time) error { - var scratch [8]byte - - // Convert to unix nano seconds, but only if time is non-zero. Calling - // UnixNano() on a zero time yields an undefined result. - var unixNano int64 - if !t.IsZero() { - unixNano = t.UnixNano() - } - - byteOrder.PutUint64(scratch[:], uint64(unixNano)) - _, err := w.Write(scratch[:]) - return err -} - // generateSphinxPacket generates then encodes a sphinx packet which encodes // the onion route specified by the passed layer 3 route. The blob returned // from this function can immediately be included within an HTLC add packet to diff --git a/channeldb/payment_status.go b/payments/db/payment_status.go similarity index 93% rename from channeldb/payment_status.go rename to payments/db/payment_status.go index 179e22fc250..3fd01cafa36 100644 --- a/channeldb/payment_status.go +++ b/payments/db/payment_status.go @@ -1,9 +1,7 @@ -package channeldb +package paymentsdb import ( "fmt" - - paymentsdb "github.com/lightningnetwork/lnd/payments/db" ) // PaymentStatus represent current status of payment. @@ -62,24 +60,24 @@ func (ps PaymentStatus) initializable() error { // again in case other goroutines have already been creating HTLCs for // it. case StatusInitiated: - return paymentsdb.ErrPaymentExists + return ErrPaymentExists // We already have an InFlight payment on the network. We will disallow // any new payments. case StatusInFlight: - return paymentsdb.ErrPaymentInFlight + return ErrPaymentInFlight // The payment has been attempted and is succeeded so we won't allow // creating it again. case StatusSucceeded: - return paymentsdb.ErrAlreadyPaid + return ErrAlreadyPaid // We allow retrying failed payments. case StatusFailed: return nil default: - return fmt.Errorf("%w: %v", paymentsdb.ErrUnknownPaymentStatus, + return fmt.Errorf("%w: %v", ErrUnknownPaymentStatus, ps) } } @@ -96,7 +94,7 @@ func (ps PaymentStatus) removable() error { // There are still inflight HTLCs and the payment needs to wait for the // final outcomes. case StatusInFlight: - return paymentsdb.ErrPaymentInFlight + return ErrPaymentInFlight // The payment has been attempted and is succeeded and is allowed to be // removed. @@ -108,7 +106,7 @@ func (ps PaymentStatus) removable() error { return nil default: - return fmt.Errorf("%w: %v", paymentsdb.ErrUnknownPaymentStatus, + return fmt.Errorf("%w: %v", ErrUnknownPaymentStatus, ps) } } @@ -127,13 +125,13 @@ func (ps PaymentStatus) updatable() error { // If the payment has a terminal condition, we won't allow any updates. case StatusSucceeded: - return paymentsdb.ErrPaymentAlreadySucceeded + return ErrPaymentAlreadySucceeded case StatusFailed: - return paymentsdb.ErrPaymentAlreadyFailed + return ErrPaymentAlreadyFailed default: - return fmt.Errorf("%w: %v", paymentsdb.ErrUnknownPaymentStatus, + return fmt.Errorf("%w: %v", ErrUnknownPaymentStatus, ps) } } diff --git a/channeldb/payment_status_test.go b/payments/db/payment_status_test.go similarity index 92% rename from channeldb/payment_status_test.go rename to payments/db/payment_status_test.go index bd6181870c4..1bb4dc3889a 100644 --- a/channeldb/payment_status_test.go +++ b/payments/db/payment_status_test.go @@ -1,11 +1,10 @@ -package channeldb +package paymentsdb import ( "fmt" "testing" "github.com/lightningnetwork/lnd/lntypes" - paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/stretchr/testify/require" ) @@ -198,33 +197,33 @@ func TestPaymentStatusActions(t *testing.T) { }{ { status: StatusInitiated, - initErr: paymentsdb.ErrPaymentExists, + initErr: ErrPaymentExists, updateErr: nil, removeErr: nil, }, { status: StatusInFlight, - initErr: paymentsdb.ErrPaymentInFlight, + initErr: ErrPaymentInFlight, updateErr: nil, - removeErr: paymentsdb.ErrPaymentInFlight, + removeErr: ErrPaymentInFlight, }, { status: StatusSucceeded, - initErr: paymentsdb.ErrAlreadyPaid, - updateErr: paymentsdb.ErrPaymentAlreadySucceeded, + initErr: ErrAlreadyPaid, + updateErr: ErrPaymentAlreadySucceeded, removeErr: nil, }, { status: StatusFailed, initErr: nil, - updateErr: paymentsdb.ErrPaymentAlreadyFailed, + updateErr: ErrPaymentAlreadyFailed, removeErr: nil, }, { status: 0, - initErr: paymentsdb.ErrUnknownPaymentStatus, - updateErr: paymentsdb.ErrUnknownPaymentStatus, - removeErr: paymentsdb.ErrUnknownPaymentStatus, + initErr: ErrUnknownPaymentStatus, + updateErr: ErrUnknownPaymentStatus, + removeErr: ErrUnknownPaymentStatus, }, } diff --git a/channeldb/payments_kv_store_test.go b/payments/db/payment_test.go similarity index 58% rename from channeldb/payments_kv_store_test.go rename to payments/db/payment_test.go index d1953fa95d0..5e75dc5b9b5 100644 --- a/channeldb/payments_kv_store_test.go +++ b/payments/db/payment_test.go @@ -1,7 +1,7 @@ -package channeldb +package paymentsdb import ( - "bytes" + "context" "crypto/rand" "crypto/sha256" "errors" @@ -11,611 +11,531 @@ import ( "testing" "time" - "github.com/btcsuite/btcwallet/walletdb" + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" - paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" - "github.com/lightningnetwork/lnd/tlv" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func genPreimage() ([32]byte, error) { - var preimage [32]byte - if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil { - return preimage, err +var ( + testHash = [32]byte{ + 0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab, + 0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4, + 0x4f, 0x2f, 0x6f, 0x25, 0x88, 0xa3, 0xef, 0xb9, + 0x6a, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53, } - return preimage, nil -} -func genInfo(t *testing.T) (*PaymentCreationInfo, *HTLCAttemptInfo, - lntypes.Preimage, error) { + rev = [chainhash.HashSize]byte{ + 0x51, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, + 0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, + 0x2d, 0xe7, 0x93, 0xe4, + } +) - preimage, err := genPreimage() - if err != nil { - return nil, nil, preimage, fmt.Errorf("unable to "+ - "generate preimage: %v", err) +var ( + priv, _ = btcec.NewPrivateKey() + pub = priv.PubKey() + vertex = route.NewVertex(pub) + + testHop1 = &route.Hop{ + PubKeyBytes: vertex, + ChannelID: 12345, + OutgoingTimeLock: 111, + AmtToForward: 555, + CustomRecords: record.CustomSet{ + 65536: []byte{}, + 80001: []byte{}, + }, + MPP: record.NewMPP(32, [32]byte{0x42}), + Metadata: []byte{1, 2, 3}, } - rhash := sha256.Sum256(preimage[:]) - var hash lntypes.Hash - copy(hash[:], rhash[:]) + testHop2 = &route.Hop{ + PubKeyBytes: vertex, + ChannelID: 12345, + OutgoingTimeLock: 111, + AmtToForward: 555, + LegacyPayload: true, + } - attempt, err := NewHtlcAttempt( - 0, priv, *testRoute.Copy(), time.Time{}, &hash, - ) - require.NoError(t, err) + testRoute = route.Route{ + TotalTimeLock: 123, + TotalAmount: 1234567, + SourcePubKey: vertex, + Hops: []*route.Hop{ + testHop2, + testHop1, + }, + } - return &PaymentCreationInfo{ - PaymentIdentifier: rhash, - Value: testRoute.ReceiverAmt(), - CreationTime: time.Unix(time.Now().Unix(), 0), - PaymentRequest: []byte("hola"), - }, &attempt.HTLCAttemptInfo, preimage, nil + testBlindedRoute = route.Route{ + TotalTimeLock: 150, + TotalAmount: 1000, + SourcePubKey: vertex, + Hops: []*route.Hop{ + { + PubKeyBytes: vertex, + ChannelID: 9876, + OutgoingTimeLock: 120, + AmtToForward: 900, + EncryptedData: []byte{1, 3, 3}, + BlindingPoint: pub, + }, + { + PubKeyBytes: vertex, + EncryptedData: []byte{3, 2, 1}, + }, + { + PubKeyBytes: vertex, + Metadata: []byte{4, 5, 6}, + AmtToForward: 500, + OutgoingTimeLock: 100, + TotalAmtMsat: 500, + }, + }, + } +) + +// payment is a helper structure that holds basic information on a test payment, +// such as the payment id, the status and the total number of HTLCs attempted. +type payment struct { + id lntypes.Hash + status PaymentStatus + htlcs int } -// TestKVPaymentsDBSwitchFail checks that payment status returns to Failed -// status after failing, and that InitPayment allows another HTLC for the -// same payment hash. -func TestKVPaymentsDBSwitchFail(t *testing.T) { - t.Parallel() +// createTestPayments registers payments depending on the provided statuses in +// the payments slice. Each payment will receive one failed HTLC and another +// HTLC depending on the final status of the payment provided. +func createTestPayments(t *testing.T, p PaymentDB, payments []*payment) { + attemptID := uint64(0) - db, err := MakeTestDB(t) - require.NoError(t, err, "unable to init db") + for i := 0; i < len(payments); i++ { + info, attempt, preimg, err := genInfo(t) + require.NoError(t, err, "unable to generate htlc message") - paymentDB, err := NewKVPaymentsDB(db) - require.NoError(t, err) + // Set the payment id accordingly in the payments slice. + payments[i].id = info.PaymentIdentifier - info, attempt, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + attempt.AttemptID = attemptID + attemptID++ - // Sends base htlc message which initiate StatusInFlight. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) - require.NoError(t, err, "unable to send htlc message") + // Init the payment. + err = p.InitPayment(info.PaymentIdentifier, info) + require.NoError(t, err, "unable to send htlc message") - assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, StatusInitiated, - ) - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, nil, - ) + // Register and fail the first attempt for all payments. + _, err = p.RegisterAttempt(info.PaymentIdentifier, attempt) + require.NoError(t, err, "unable to send htlc message") - // Fail the payment, which should moved it to Failed. - failReason := FailureReasonNoRoute - _, err = paymentDB.Fail(info.PaymentIdentifier, failReason) - require.NoError(t, err, "unable to fail payment hash") + htlcFailure := HTLCFailUnreadable + _, err = p.FailAttempt( + info.PaymentIdentifier, attempt.AttemptID, + &HTLCFailInfo{ + Reason: htlcFailure, + }, + ) + require.NoError(t, err, "unable to fail htlc") - // Verify the status is indeed Failed. - assertPaymentStatus(t, paymentDB, info.PaymentIdentifier, StatusFailed) - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, &failReason, nil, - ) + // Increase the HTLC counter in the payments slice for the + // failed attempt. + payments[i].htlcs++ - // Lookup the payment so we can get its old sequence number before it is - // overwritten. - payment, err := paymentDB.FetchPayment(info.PaymentIdentifier) - require.NoError(t, err) + // Depending on the test case, fail or succeed the next + // attempt. + attempt.AttemptID = attemptID + attemptID++ - // Sends the htlc again, which should succeed since the prior payment - // failed. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) - require.NoError(t, err, "unable to send htlc message") + _, err = p.RegisterAttempt(info.PaymentIdentifier, attempt) + require.NoError(t, err, "unable to send htlc message") - // Check that our index has been updated, and the old index has been - // removed. - assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) - assertNoIndex(t, paymentDB, payment.SequenceNum) + switch payments[i].status { + // Fail the attempt and the payment overall. + case StatusFailed: + htlcFailure := HTLCFailUnreadable + _, err = p.FailAttempt( + info.PaymentIdentifier, attempt.AttemptID, + &HTLCFailInfo{ + Reason: htlcFailure, + }, + ) + require.NoError(t, err, "unable to fail htlc") - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, StatusInitiated, - ) - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, nil, - ) + failReason := FailureReasonNoRoute + _, err = p.Fail(info.PaymentIdentifier, + failReason) + require.NoError(t, err, "unable to fail payment hash") - // Record a new attempt. In this test scenario, the attempt fails. - // However, this is not communicated to control tower in the current - // implementation. It only registers the initiation of the attempt. - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) - require.NoError(t, err, "unable to register attempt") + // Settle the attempt + case StatusSucceeded: + _, err := p.SettleAttempt( + info.PaymentIdentifier, attempt.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + require.NoError(t, err, "no error should have been "+ + "received from settling a htlc attempt") - htlcReason := HTLCFailUnreadable - _, err = paymentDB.FailAttempt( - info.PaymentIdentifier, attempt.AttemptID, - &HTLCFailInfo{ - Reason: htlcReason, - }, - ) - if err != nil { - t.Fatal(err) - } - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, StatusInFlight, - ) + // We leave the attempt in-flight by doing nothing. + case StatusInFlight: + } - htlc := &htlcStatus{ - HTLCAttemptInfo: attempt, - failure: &htlcReason, + // Increase the HTLC counter in the payments slice for any + // attempt above. + payments[i].htlcs++ } +} - assertPaymentInfo(t, paymentDB, info.PaymentIdentifier, info, nil, htlc) - - // Record another attempt. - attempt.AttemptID = 1 - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) - require.NoError(t, err, "unable to send htlc message") - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, StatusInFlight, - ) - - htlc = &htlcStatus{ - HTLCAttemptInfo: attempt, +// assertRouteEquals compares to routes for equality and returns an error if +// they are not equal. +func assertRouteEqual(a, b *route.Route) error { + if !reflect.DeepEqual(a, b) { + return fmt.Errorf("HTLCAttemptInfos don't match: %v vs %v", + spew.Sdump(a), spew.Sdump(b)) } - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, htlc, - ) + return nil +} - // Settle the attempt and verify that status was changed to - // StatusSucceeded. - payment, err = paymentDB.SettleAttempt( - info.PaymentIdentifier, attempt.AttemptID, - &HTLCSettleInfo{ - Preimage: preimg, - }, - ) - require.NoError(t, err, "error shouldn't have been received, got") +// assertPaymentInfo retrieves the payment referred to by hash and verifies the +// expected values. +func assertPaymentInfo(t *testing.T, p PaymentDB, hash lntypes.Hash, + c *PaymentCreationInfo, f *FailureReason, + a *htlcStatus) { - if len(payment.HTLCs) != 2 { - t.Fatalf("payment should have two htlcs, got: %d", - len(payment.HTLCs)) - } + t.Helper() - err = assertRouteEqual(&payment.HTLCs[0].Route, &attempt.Route) + payment, err := p.FetchPayment(hash) if err != nil { - t.Fatalf("unexpected route returned: %v vs %v: %v", - spew.Sdump(attempt.Route), - spew.Sdump(payment.HTLCs[0].Route), err) + t.Fatal(err) } - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, StatusSucceeded, - ) - - htlc.settle = &preimg - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, htlc, - ) + if !reflect.DeepEqual(payment.Info, c) { + t.Fatalf("PaymentCreationInfos don't match: %v vs %v", + spew.Sdump(payment.Info), spew.Sdump(c)) + } - // Attempt a final payment, which should now fail since the prior - // payment succeed. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) - if !errors.Is(err, paymentsdb.ErrAlreadyPaid) { - t.Fatalf("unable to send htlc message: %v", err) + if f != nil { + if *payment.FailureReason != *f { + t.Fatal("unexpected failure reason") + } + } else { + if payment.FailureReason != nil { + t.Fatal("unexpected failure reason") + } } -} -// TestKVPaymentsDBSwitchDoubleSend checks the ability of payment control to -// prevent double sending of htlc message, when message is in StatusInFlight. -func TestKVPaymentsDBSwitchDoubleSend(t *testing.T) { - t.Parallel() + if a == nil { + if len(payment.HTLCs) > 0 { + t.Fatal("expected no htlcs") + } - db, err := MakeTestDB(t) - require.NoError(t, err, "unable to init db") + return + } - paymentDB, err := NewKVPaymentsDB(db) - require.NoError(t, err) + htlc := payment.HTLCs[a.AttemptID] + if err := assertRouteEqual(&htlc.Route, &a.Route); err != nil { + t.Fatal("routes do not match") + } - info, attempt, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + if htlc.AttemptID != a.AttemptID { + t.Fatalf("unnexpected attempt ID %v, expected %v", + htlc.AttemptID, a.AttemptID) + } - // Sends base htlc message which initiate base status and move it to - // StatusInFlight and verifies that it was changed. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) - require.NoError(t, err, "unable to send htlc message") + if a.failure != nil { + if htlc.Failure == nil { + t.Fatalf("expected HTLC to be failed") + } - assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, StatusInitiated, - ) - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, nil, - ) + if htlc.Failure.Reason != *a.failure { + t.Fatalf("expected HTLC failure %v, had %v", + *a.failure, htlc.Failure.Reason) + } + } else if htlc.Failure != nil { + t.Fatalf("expected no HTLC failure") + } - // Try to initiate double sending of htlc message with the same - // payment hash, should result in error indicating that payment has - // already been sent. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) - require.ErrorIs(t, err, paymentsdb.ErrPaymentExists) + if a.settle != nil { + if htlc.Settle.Preimage != *a.settle { + t.Fatalf("Preimages don't match: %x vs %x", + htlc.Settle.Preimage, a.settle) + } + } else if htlc.Settle != nil { + t.Fatal("expected no settle info") + } +} - // Record an attempt. - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) - require.NoError(t, err, "unable to send htlc message") - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, StatusInFlight, - ) +// assertPaymentStatus retrieves the status of the payment referred to by hash +// and compares it with the expected state. +func assertPaymentStatus(t *testing.T, p PaymentDB, + hash lntypes.Hash, expStatus PaymentStatus) { - htlc := &htlcStatus{ - HTLCAttemptInfo: attempt, + t.Helper() + + payment, err := p.FetchPayment(hash) + if errors.Is(err, ErrPaymentNotInitiated) { + return + } + if err != nil { + t.Fatal(err) } - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, htlc, - ) - // Sends base htlc message which initiate StatusInFlight. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) - if !errors.Is(err, paymentsdb.ErrPaymentInFlight) { - t.Fatalf("payment control wrong behaviour: " + - "double sending must trigger ErrPaymentInFlight error") + if payment.Status != expStatus { + t.Fatalf("payment status mismatch: expected %v, got %v", + expStatus, payment.Status) } +} - // After settling, the error should be ErrAlreadyPaid. - _, err = paymentDB.SettleAttempt( - info.PaymentIdentifier, attempt.AttemptID, - &HTLCSettleInfo{ - Preimage: preimg, +// assertPayments is a helper function that given a slice of payment and +// indices for the slice asserts that exactly the same payments in the +// slice for the provided indices exist when fetching payments from the +// database. +func assertPayments(t *testing.T, paymentDB PaymentDB, + payments []*payment) { + + t.Helper() + + response, err := paymentDB.QueryPayments( + context.Background(), Query{ + IndexOffset: 0, + MaxPayments: uint64(len(payments)), + IncludeIncomplete: true, }, ) - require.NoError(t, err, "error shouldn't have been received, got") - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, StatusSucceeded, - ) + require.NoError(t, err, "could not fetch payments from db") - htlc.settle = &preimg - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, htlc, + dbPayments := response.Payments + + // Make sure that the number of fetched payments is the same + // as expected. + require.Len( + t, dbPayments, len(payments), "unexpected number of payments", ) - err = paymentDB.InitPayment(info.PaymentIdentifier, info) - if !errors.Is(err, paymentsdb.ErrAlreadyPaid) { - t.Fatalf("unable to send htlc message: %v", err) + // Convert fetched payments of type MPPayment to our helper structure. + p := make([]*payment, len(dbPayments)) + for i, dbPayment := range dbPayments { + p[i] = &payment{ + id: dbPayment.Info.PaymentIdentifier, + status: dbPayment.Status, + htlcs: len(dbPayment.HTLCs), + } } + + // Check that each payment we want to assert exists in the database. + require.Equal(t, payments, p) } -// TestKVPaymentsDBSuccessesWithoutInFlight checks that the payment -// control will disallow calls to Success when no payment is in flight. -func TestKVPaymentsDBSuccessesWithoutInFlight(t *testing.T) { - t.Parallel() +func genPreimage() ([32]byte, error) { + var preimage [32]byte + if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil { + return preimage, err + } - db, err := MakeTestDB(t) - require.NoError(t, err, "unable to init db") + return preimage, nil +} - paymentDB, err := NewKVPaymentsDB(db) - require.NoError(t, err) +func genInfo(t *testing.T) (*PaymentCreationInfo, *HTLCAttemptInfo, + lntypes.Preimage, error) { - info, _, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimage, err := genPreimage() + if err != nil { + return nil, nil, preimage, fmt.Errorf("unable to "+ + "generate preimage: %v", err) + } - // Attempt to complete the payment should fail. - _, err = paymentDB.SettleAttempt( - info.PaymentIdentifier, 0, - &HTLCSettleInfo{ - Preimage: preimg, - }, + rhash := sha256.Sum256(preimage[:]) + var hash lntypes.Hash + copy(hash[:], rhash[:]) + + attempt, err := NewHtlcAttempt( + 0, priv, *testRoute.Copy(), time.Time{}, &hash, ) - require.ErrorIs(t, err, paymentsdb.ErrPaymentNotInitiated) + require.NoError(t, err) + + return &PaymentCreationInfo{ + PaymentIdentifier: rhash, + Value: testRoute.ReceiverAmt(), + CreationTime: time.Unix(time.Now().Unix(), 0), + PaymentRequest: []byte("hola"), + }, &attempt.HTLCAttemptInfo, preimage, nil } -// TestKVPaymentsDBFailsWithoutInFlight checks that a strict payment -// control will disallow calls to Fail when no payment is in flight. -func TestKVPaymentsDBFailsWithoutInFlight(t *testing.T) { +// TestDeleteFailedAttempts checks that DeleteFailedAttempts properly removes +// failed HTLCs from finished payments. +func TestDeleteFailedAttempts(t *testing.T) { t.Parallel() - db, err := MakeTestDB(t) - require.NoError(t, err, "unable to init db") + t.Run("keep failed payment attempts", func(t *testing.T) { + testDeleteFailedAttempts(t, true) + }) + t.Run("remove failed payment attempts", func(t *testing.T) { + testDeleteFailedAttempts(t, false) + }) +} - paymentDB, err := NewKVPaymentsDB(db) - require.NoError(t, err) +func testDeleteFailedAttempts(t *testing.T, keepFailedPaymentAttempts bool) { + paymentDB := NewTestDB( + t, WithKeepFailedPaymentAttempts(keepFailedPaymentAttempts), + ) - info, _, _, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + // Register three payments: + // All payments will have one failed HTLC attempt and one HTLC attempt + // according to its final status. + // 1. A payment with two failed attempts. + // 2. A payment with one failed and one in-flight attempt. + // 3. A payment with one failed and one settled attempt. - // Calling Fail should return an error. - _, err = paymentDB.Fail(info.PaymentIdentifier, FailureReasonNoRoute) - require.ErrorIs(t, err, paymentsdb.ErrPaymentNotInitiated) -} + // Initiate payments, which is a slice of payment that is used as + // template to create the corresponding test payments in the database. + // + // Note: The payment id and number of htlc attempts of each payment will + // be added to this slice when creating the payments below. + // This allows the slice to be used directly for testing purposes. + payments := []*payment{ + {status: StatusFailed}, + {status: StatusInFlight}, + {status: StatusSucceeded}, + } -// TestKVPaymentsDBDeleteNonInFlight checks that calling DeletePayments only -// deletes payments from the database that are not in-flight. -func TestKVPaymentsDBDeleteNonInFlight(t *testing.T) { - t.Parallel() + // Use helper function to register the test payments in the data and + // populate the data to the payments slice. + createTestPayments(t, paymentDB, payments) - db, err := MakeTestDB(t) - require.NoError(t, err, "unable to init db") + // Check that all payments are there as we added them. + assertPayments(t, paymentDB, payments) - // Create a sequence number for duplicate payments that will not collide - // with the sequence numbers for the payments we create. These values - // start at 1, so 9999 is a safe bet for this test. - var duplicateSeqNr = 9999 + // Calling DeleteFailedAttempts on a failed payment should delete all + // HTLCs. + require.NoError(t, paymentDB.DeleteFailedAttempts(payments[0].id)) - paymentDB, err := NewKVPaymentsDB(db) - require.NoError(t, err) + // Expect all HTLCs to be deleted if the config is set to delete them. + if !keepFailedPaymentAttempts { + payments[0].htlcs = 0 + } + assertPayments(t, paymentDB, payments) - payments := []struct { - failed bool - success bool - hasDuplicate bool - }{ - { - failed: true, - success: false, - hasDuplicate: false, - }, - { - failed: false, - success: true, - hasDuplicate: false, - }, - { - failed: false, - success: false, - hasDuplicate: false, - }, - { - failed: false, - success: true, - hasDuplicate: true, - }, + // Calling DeleteFailedAttempts on an in-flight payment should return + // an error. + if keepFailedPaymentAttempts { + require.NoError( + t, paymentDB.DeleteFailedAttempts(payments[1].id), + ) + } else { + require.Error(t, paymentDB.DeleteFailedAttempts(payments[1].id)) } - var numSuccess, numInflight int + // Since DeleteFailedAttempts returned an error, we should expect the + // payment to be unchanged. + assertPayments(t, paymentDB, payments) - for _, p := range payments { - info, attempt, preimg, err := genInfo(t) - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } + // Cleaning up a successful payment should remove failed htlcs. + require.NoError(t, paymentDB.DeleteFailedAttempts(payments[2].id)) + // Expect all HTLCs except for the settled one to be deleted if the + // config is set to delete them. + if !keepFailedPaymentAttempts { + payments[2].htlcs = 1 + } + assertPayments(t, paymentDB, payments) - // Sends base htlc message which initiate StatusInFlight. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - _, err = paymentDB.RegisterAttempt( - info.PaymentIdentifier, attempt, + if keepFailedPaymentAttempts { + // DeleteFailedAttempts is ignored, even for non-existent + // payments, if the control tower is configured to keep failed + // HTLCs. + require.NoError( + t, paymentDB.DeleteFailedAttempts(lntypes.ZeroHash), ) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } + } else { + // Attempting to cleanup a non-existent payment returns an + // error. + require.Error( + t, paymentDB.DeleteFailedAttempts(lntypes.ZeroHash), + ) + } +} - htlc := &htlcStatus{ - HTLCAttemptInfo: attempt, - } +// TestPaymentMPPRecordValidation tests MPP record validation. +func TestPaymentMPPRecordValidation(t *testing.T) { + t.Parallel() - if p.failed { - // Fail the payment attempt. - htlcFailure := HTLCFailUnreadable - _, err := paymentDB.FailAttempt( - info.PaymentIdentifier, attempt.AttemptID, - &HTLCFailInfo{ - Reason: htlcFailure, - }, - ) - if err != nil { - t.Fatalf("unable to fail htlc: %v", err) - } + paymentDB := NewTestDB(t) - // Fail the payment, which should moved it to Failed. - failReason := FailureReasonNoRoute - _, err = paymentDB.Fail( - info.PaymentIdentifier, failReason, - ) - if err != nil { - t.Fatalf("unable to fail payment hash: %v", err) - } - - // Verify the status is indeed Failed. - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, - StatusFailed, - ) - - htlc.failure = &htlcFailure - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, - &failReason, htlc, - ) - } else if p.success { - // Verifies that status was changed to StatusSucceeded. - _, err := paymentDB.SettleAttempt( - info.PaymentIdentifier, attempt.AttemptID, - &HTLCSettleInfo{ - Preimage: preimg, - }, - ) - if err != nil { - t.Fatalf("error shouldn't have been received,"+ - " got: %v", err) - } - - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, - StatusSucceeded, - ) - - htlc.settle = &preimg - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, - htlc, - ) - - numSuccess++ - } else { - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, - StatusInFlight, - ) - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, - htlc, - ) - - numInflight++ - } - - // If the payment is intended to have a duplicate payment, we - // add one. - if p.hasDuplicate { - appendDuplicatePayment( - t, paymentDB.db, info.PaymentIdentifier, - uint64(duplicateSeqNr), preimg, - ) - duplicateSeqNr++ - numSuccess++ - } - } - - // Delete all failed payments. - numPayments, err := paymentDB.DeletePayments(true, false) - require.NoError(t, err) - require.EqualValues(t, 1, numPayments) - - // This should leave the succeeded and in-flight payments. - dbPayments, err := paymentDB.FetchPayments() - if err != nil { - t.Fatal(err) - } - - if len(dbPayments) != numSuccess+numInflight { - t.Fatalf("expected %d payments, got %d", - numSuccess+numInflight, len(dbPayments)) - } - - var s, i int - for _, p := range dbPayments { - t.Log("fetch payment has status", p.Status) - switch p.Status { - case StatusSucceeded: - s++ - case StatusInFlight: - i++ - } - } - - if s != numSuccess { - t.Fatalf("expected %d succeeded payments , got %d", - numSuccess, s) - } - if i != numInflight { - t.Fatalf("expected %d in-flight payments, got %d", - numInflight, i) - } - - // Now delete all payments except in-flight. - numPayments, err = paymentDB.DeletePayments(false, false) - require.NoError(t, err) - require.EqualValues(t, 2, numPayments) - - // This should leave the in-flight payment. - dbPayments, err = paymentDB.FetchPayments() - if err != nil { - t.Fatal(err) - } - - if len(dbPayments) != numInflight { - t.Fatalf("expected %d payments, got %d", numInflight, - len(dbPayments)) - } - - for _, p := range dbPayments { - if p.Status != StatusInFlight { - t.Fatalf("expected in-fligth status, got %v", p.Status) - } - } - - // Finally, check that we only have a single index left in the payment - // index bucket. - var indexCount int - err = kvdb.View(db, func(tx walletdb.ReadTx) error { - index := tx.ReadBucket(paymentsIndexBucket) - - return index.ForEach(func(k, v []byte) error { - indexCount++ - return nil - }) - }, func() { indexCount = 0 }) - require.NoError(t, err) - - require.Equal(t, 1, indexCount) -} - -// TestKVPaymentsDBDeletePayments tests that DeletePayments correctly deletes -// information about completed payments from the database. -func TestKVPaymentsDBDeletePayments(t *testing.T) { - t.Parallel() - - db, err := MakeTestDB(t) - require.NoError(t, err, "unable to init db") - - paymentDB, err := NewKVPaymentsDB(db) - require.NoError(t, err) + info, attempt, _, err := genInfo(t) + require.NoError(t, err, "unable to generate htlc message") - // Register three payments: - // 1. A payment with two failed attempts. - // 2. A payment with one failed and one settled attempt. - // 3. A payment with one failed and one in-flight attempt. - payments := []*payment{ - {status: StatusFailed}, - {status: StatusSucceeded}, - {status: StatusInFlight}, - } + // Init the payment. + err = paymentDB.InitPayment(info.PaymentIdentifier, info) + require.NoError(t, err, "unable to send htlc message") - // Use helper function to register the test payments in the data and - // populate the data to the payments slice. - createTestPayments(t, paymentDB, payments) + // Create three unique attempts we'll use for the test, and + // register them with the payment control. We set each + // attempts's value to one third of the payment amount, and + // populate the MPP options. + shardAmt := info.Value / 3 + attempt.Route.FinalHop().AmtToForward = shardAmt + attempt.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) - // Check that all payments are there as we added them. - assertPayments(t, paymentDB, payments) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) + require.NoError(t, err, "unable to send htlc message") - // Delete HTLC attempts for failed payments only. - numPayments, err := paymentDB.DeletePayments(true, true) - require.NoError(t, err) - require.EqualValues(t, 0, numPayments) + // Now try to register a non-MPP attempt, which should fail. + b := *attempt + b.AttemptID = 1 + b.Route.FinalHop().MPP = nil + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + require.ErrorIs(t, err, ErrMPPayment) - // The failed payment is the only altered one. - payments[0].htlcs = 0 - assertPayments(t, paymentDB, payments) + // Try to register attempt one with a different payment address. + b.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{2}, + ) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + require.ErrorIs(t, err, ErrMPPPaymentAddrMismatch) - // Delete failed attempts for all payments. - numPayments, err = paymentDB.DeletePayments(false, true) - require.NoError(t, err) - require.EqualValues(t, 0, numPayments) + // Try registering one with a different total amount. + b.Route.FinalHop().MPP = record.NewMPP( + info.Value/2, [32]byte{1}, + ) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + require.ErrorIs(t, err, ErrMPPTotalAmountMismatch) - // The failed attempts should be deleted, except for the in-flight - // payment, that shouldn't be altered until it has completed. - payments[1].htlcs = 1 - assertPayments(t, paymentDB, payments) + // Create and init a new payment. This time we'll check that we cannot + // register an MPP attempt if we already registered a non-MPP one. + info, attempt, _, err = genInfo(t) + require.NoError(t, err, "unable to generate htlc message") - // Now delete all failed payments. - numPayments, err = paymentDB.DeletePayments(true, false) - require.NoError(t, err) - require.EqualValues(t, 1, numPayments) + err = paymentDB.InitPayment(info.PaymentIdentifier, info) + require.NoError(t, err, "unable to send htlc message") - assertPayments(t, paymentDB, payments[1:]) + attempt.Route.FinalHop().MPP = nil + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) + require.NoError(t, err, "unable to send htlc message") - // Finally delete all completed payments. - numPayments, err = paymentDB.DeletePayments(false, false) - require.NoError(t, err) - require.EqualValues(t, 1, numPayments) + // Attempt to register an MPP attempt, which should fail. + b = *attempt + b.AttemptID = 1 + b.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) - assertPayments(t, paymentDB, payments[2:]) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + require.ErrorIs(t, err, ErrNonMPPayment) } -// TestKVPaymentsDBDeleteSinglePayment tests that DeletePayment correctly +// TestDeleteSinglePayment tests that DeletePayment correctly // deletes information about a completed payment from the database. -func TestKVPaymentsDBDeleteSinglePayment(t *testing.T) { +func TestDeleteSinglePayment(t *testing.T) { t.Parallel() - db, err := MakeTestDB(t) - require.NoError(t, err, "unable to init db") - - paymentDB, err := NewKVPaymentsDB(db) - require.NoError(t, err) + paymentDB := NewTestDB(t) // Register four payments: // All payments will have one failed HTLC attempt and one HTLC attempt @@ -694,1109 +614,1164 @@ func TestKVPaymentsDBDeleteSinglePayment(t *testing.T) { assertPayments(t, paymentDB, payments[3:]) } -// TestKVPaymentsDBMultiShard checks the ability of payment control to -// have multiple in-flight HTLCs for a single payment. -func TestKVPaymentsDBMultiShard(t *testing.T) { +// TestPaymentRegistrable checks the method `Registrable` behaves as expected +// for ALL possible payment statuses. +func TestPaymentRegistrable(t *testing.T) { t.Parallel() - // We will register three HTLC attempts, and always fail the second - // one. We'll generate all combinations of settling/failing the first - // and third HTLC, and assert that the payment status end up as we - // expect. - type testCase struct { - settleFirst bool - settleLast bool + testCases := []struct { + status PaymentStatus + registryErr error + hasSettledHTLC bool + paymentFailed bool + }{ + { + status: StatusInitiated, + registryErr: nil, + }, + { + // Test inflight status with no settled HTLC and no + // failed payment. + status: StatusInFlight, + registryErr: nil, + }, + { + // Test inflight status with settled HTLC but no failed + // payment. + status: StatusInFlight, + registryErr: ErrPaymentPendingSettled, + hasSettledHTLC: true, + }, + { + // Test inflight status with no settled HTLC but failed + // payment. + status: StatusInFlight, + registryErr: ErrPaymentPendingFailed, + paymentFailed: true, + }, + { + // Test error state with settled HTLC and failed + // payment. + status: 0, + registryErr: ErrUnknownPaymentStatus, + hasSettledHTLC: true, + paymentFailed: true, + }, + { + status: StatusSucceeded, + registryErr: ErrPaymentAlreadySucceeded, + }, + { + status: StatusFailed, + registryErr: ErrPaymentAlreadyFailed, + }, + { + status: 0, + registryErr: ErrUnknownPaymentStatus, + }, } - var tests []testCase - for _, f := range []bool{true, false} { - for _, l := range []bool{true, false} { - tests = append(tests, testCase{f, l}) + for i, tc := range testCases { + i, tc := i, tc + + p := &MPPayment{ + Status: tc.status, + State: &MPPaymentState{ + HasSettledHTLC: tc.hasSettledHTLC, + PaymentFailed: tc.paymentFailed, + }, } + + name := fmt.Sprintf("test_%d_%s", i, p.Status.String()) + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := p.Registrable() + require.ErrorIs(t, err, tc.registryErr, + "registrable under state %v", tc.status) + }) } +} - runSubTest := func(t *testing.T, test testCase) { - db, err := MakeTestDB(t) - if err != nil { - t.Fatalf("unable to init db: %v", err) - } +// TestPaymentSetState checks that the method setState creates the +// MPPaymentState as expected. +func TestPaymentSetState(t *testing.T) { + t.Parallel() - paymentDB, err := NewKVPaymentsDB(db) - require.NoError(t, err) + // Create a test preimage and failure reason. + preimage := lntypes.Preimage{1} + failureReasonError := FailureReasonError - info, attempt, preimg, err := genInfo(t) - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } + testCases := []struct { + name string + payment *MPPayment + totalAmt int - // Init the payment, moving it to the StatusInFlight state. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } + expectedState *MPPaymentState + errExpected error + }{ + { + // Test that when the sentAmt exceeds totalAmount, the + // error is returned. + name: "amount exceeded error", + // SentAmt returns 90, 10 + // TerminalInfo returns non-nil, nil + // InFlightHTLCs returns 0 + payment: &MPPayment{ + HTLCs: []HTLCAttempt{ + makeSettledAttempt(100, 10, preimage), + }, + }, + totalAmt: 1, + errExpected: ErrSentExceedsTotal, + }, + { + // Test that when the htlc is failed, the fee is not + // used. + name: "fee excluded for failed htlc", + payment: &MPPayment{ + // SentAmt returns 90, 10 + // TerminalInfo returns nil, nil + // InFlightHTLCs returns 1 + HTLCs: []HTLCAttempt{ + makeActiveAttempt(100, 10), + makeFailedAttempt(100, 10), + }, + }, + totalAmt: 1000, + expectedState: &MPPaymentState{ + NumAttemptsInFlight: 1, + RemainingAmt: 1000 - 90, + FeesPaid: 10, + HasSettledHTLC: false, + PaymentFailed: false, + }, + }, + { + // Test when the payment is settled, the state should + // be marked as terminated. + name: "payment settled", + // SentAmt returns 90, 10 + // TerminalInfo returns non-nil, nil + // InFlightHTLCs returns 0 + payment: &MPPayment{ + HTLCs: []HTLCAttempt{ + makeSettledAttempt(100, 10, preimage), + }, + }, + totalAmt: 1000, + expectedState: &MPPaymentState{ + NumAttemptsInFlight: 0, + RemainingAmt: 1000 - 90, + FeesPaid: 10, + HasSettledHTLC: true, + PaymentFailed: false, + }, + }, + { + // Test when the payment is failed, the state should be + // marked as terminated. + name: "payment failed", + // SentAmt returns 0, 0 + // TerminalInfo returns nil, non-nil + // InFlightHTLCs returns 0 + payment: &MPPayment{ + FailureReason: &failureReasonError, + }, + totalAmt: 1000, + expectedState: &MPPaymentState{ + NumAttemptsInFlight: 0, + RemainingAmt: 1000, + FeesPaid: 0, + HasSettledHTLC: false, + PaymentFailed: true, + }, + }, + } - assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, StatusInitiated, - ) - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, nil, - ) + for _, tc := range testCases { + tc := tc - // Create three unique attempts we'll use for the test, and - // register them with the payment control. We set each - // attempts's value to one third of the payment amount, and - // populate the MPP options. - shardAmt := info.Value / 3 - attempt.Route.FinalHop().AmtToForward = shardAmt - attempt.Route.FinalHop().MPP = record.NewMPP( - info.Value, [32]byte{1}, - ) + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - var attempts []*HTLCAttemptInfo - for i := uint64(0); i < 3; i++ { - a := *attempt - a.AttemptID = i - attempts = append(attempts, &a) - - _, err = paymentDB.RegisterAttempt( - info.PaymentIdentifier, &a, - ) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, - StatusInFlight, - ) - - htlc := &htlcStatus{ - HTLCAttemptInfo: &a, - } - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, - htlc, - ) - } - - // For a fourth attempt, check that attempting to - // register it will fail since the total sent amount - // will be too large. - b := *attempt - b.AttemptID = 3 - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) - require.ErrorIs(t, err, paymentsdb.ErrValueExceedsAmt) - - // Fail the second attempt. - a := attempts[1] - htlcFail := HTLCFailUnreadable - _, err = paymentDB.FailAttempt( - info.PaymentIdentifier, a.AttemptID, - &HTLCFailInfo{ - Reason: htlcFail, - }, - ) - if err != nil { - t.Fatal(err) - } - - htlc := &htlcStatus{ - HTLCAttemptInfo: a, - failure: &htlcFail, - } - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, htlc, - ) - - // Payment should still be in-flight. - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, StatusInFlight, - ) - - // Depending on the test case, settle or fail the first attempt. - a = attempts[0] - htlc = &htlcStatus{ - HTLCAttemptInfo: a, - } - - var firstFailReason *FailureReason - if test.settleFirst { - _, err := paymentDB.SettleAttempt( - info.PaymentIdentifier, a.AttemptID, - &HTLCSettleInfo{ - Preimage: preimg, - }, - ) - if err != nil { - t.Fatalf("error shouldn't have been "+ - "received, got: %v", err) - } - - // Assert that the HTLC has had the preimage recorded. - htlc.settle = &preimg - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, - htlc, - ) - } else { - _, err := paymentDB.FailAttempt( - info.PaymentIdentifier, a.AttemptID, - &HTLCFailInfo{ - Reason: htlcFail, - }, - ) - if err != nil { - t.Fatalf("error shouldn't have been "+ - "received, got: %v", err) - } - - // Assert the failure was recorded. - htlc.failure = &htlcFail - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, nil, - htlc, - ) - - // We also record a payment level fail, to move it into - // a terminal state. - failReason := FailureReasonNoRoute - _, err = paymentDB.Fail( - info.PaymentIdentifier, failReason, - ) - if err != nil { - t.Fatalf("unable to fail payment hash: %v", err) + // Attach the payment info. + info := &PaymentCreationInfo{ + Value: lnwire.MilliSatoshi(tc.totalAmt), } + tc.payment.Info = info - // Record the reason we failed the payment, such that - // we can assert this later in the test. - firstFailReason = &failReason - - // The payment is now considered pending fail, since - // there is still an active HTLC. - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, - StatusInFlight, - ) - } + // Call the method that updates the payment state. + err := tc.payment.setState() + require.ErrorIs(t, err, tc.errExpected) - // Try to register yet another attempt. This should fail now - // that the payment has reached a terminal condition. - b = *attempt - b.AttemptID = 3 - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) - if test.settleFirst { - require.ErrorIs( - t, err, paymentsdb.ErrPaymentPendingSettled, - ) - } else { - require.ErrorIs( - t, err, paymentsdb.ErrPaymentPendingFailed, + require.Equal( + t, tc.expectedState, tc.payment.State, + "state not updated as expected", ) - } - - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, StatusInFlight, - ) - - // Settle or fail the remaining attempt based on the testcase. - a = attempts[2] - htlc = &htlcStatus{ - HTLCAttemptInfo: a, - } - if test.settleLast { - // Settle the last outstanding attempt. - _, err = paymentDB.SettleAttempt( - info.PaymentIdentifier, a.AttemptID, - &HTLCSettleInfo{ - Preimage: preimg, - }, - ) - require.NoError(t, err, "unable to settle") - - htlc.settle = &preimg - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, - info, firstFailReason, htlc, - ) - } else { - // Fail the attempt. - _, err := paymentDB.FailAttempt( - info.PaymentIdentifier, a.AttemptID, - &HTLCFailInfo{ - Reason: htlcFail, - }, - ) - if err != nil { - t.Fatalf("error shouldn't have been "+ - "received, got: %v", err) - } - - // Assert the failure was recorded. - htlc.failure = &htlcFail - assertPaymentInfo( - t, paymentDB, info.PaymentIdentifier, info, - firstFailReason, htlc, - ) - - // Check that we can override any perevious terminal - // failure. This is to allow multiple concurrent shard - // write a terminal failure to the database without - // syncing. - failReason := FailureReasonPaymentDetails - _, err = paymentDB.Fail( - info.PaymentIdentifier, failReason, - ) - require.NoError(t, err, "unable to fail") - } - - var ( - finalStatus PaymentStatus - registerErr error - ) - - switch { - // If one of the attempts settled but the other failed with - // terminal error, we would still consider the payment is - // settled. - case test.settleFirst && !test.settleLast: - finalStatus = StatusSucceeded - registerErr = paymentsdb.ErrPaymentAlreadySucceeded - - case !test.settleFirst && test.settleLast: - finalStatus = StatusSucceeded - registerErr = paymentsdb.ErrPaymentAlreadySucceeded - - // If both failed, we end up in a failed status. - case !test.settleFirst && !test.settleLast: - finalStatus = StatusFailed - registerErr = paymentsdb.ErrPaymentAlreadyFailed - - // Otherwise, the payment has a succeed status. - case test.settleFirst && test.settleLast: - finalStatus = StatusSucceeded - registerErr = paymentsdb.ErrPaymentAlreadySucceeded - } - - assertPaymentStatus( - t, paymentDB, info.PaymentIdentifier, finalStatus, - ) - - // Finally assert we cannot register more attempts. - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) - require.Equal(t, registerErr, err) - } - - for _, test := range tests { - test := test - subTest := fmt.Sprintf("first=%v, second=%v", - test.settleFirst, test.settleLast) - - t.Run(subTest, func(t *testing.T) { - runSubTest(t, test) }) } } -func TestKVPaymentsDBMPPRecordValidation(t *testing.T) { +// TestNeedWaitAttempts checks whether we need to wait for the results of the +// HTLC attempts against ALL possible payment statuses. +func TestNeedWaitAttempts(t *testing.T) { t.Parallel() - db, err := MakeTestDB(t) - require.NoError(t, err, "unable to init db") - - paymentDB, err := NewKVPaymentsDB(db) - require.NoError(t, err) - - info, attempt, _, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") - - // Init the payment. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) - require.NoError(t, err, "unable to send htlc message") - - // Create three unique attempts we'll use for the test, and - // register them with the payment control. We set each - // attempts's value to one third of the payment amount, and - // populate the MPP options. - shardAmt := info.Value / 3 - attempt.Route.FinalHop().AmtToForward = shardAmt - attempt.Route.FinalHop().MPP = record.NewMPP( - info.Value, [32]byte{1}, - ) - - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) - require.NoError(t, err, "unable to send htlc message") - - // Now try to register a non-MPP attempt, which should fail. - b := *attempt - b.AttemptID = 1 - b.Route.FinalHop().MPP = nil - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) - require.ErrorIs(t, err, paymentsdb.ErrMPPayment) - - // Try to register attempt one with a different payment address. - b.Route.FinalHop().MPP = record.NewMPP( - info.Value, [32]byte{2}, - ) - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) - require.ErrorIs(t, err, paymentsdb.ErrMPPPaymentAddrMismatch) - - // Try registering one with a different total amount. - b.Route.FinalHop().MPP = record.NewMPP( - info.Value/2, [32]byte{1}, - ) - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) - require.ErrorIs(t, err, paymentsdb.ErrMPPTotalAmountMismatch) - - // Create and init a new payment. This time we'll check that we cannot - // register an MPP attempt if we already registered a non-MPP one. - info, attempt, _, err = genInfo(t) - require.NoError(t, err, "unable to generate htlc message") - - err = paymentDB.InitPayment(info.PaymentIdentifier, info) - require.NoError(t, err, "unable to send htlc message") - - attempt.Route.FinalHop().MPP = nil - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) - require.NoError(t, err, "unable to send htlc message") - - // Attempt to register an MPP attempt, which should fail. - b = *attempt - b.AttemptID = 1 - b.Route.FinalHop().MPP = record.NewMPP( - info.Value, [32]byte{1}, - ) - - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) - require.ErrorIs(t, err, paymentsdb.ErrNonMPPayment) -} - -// TestDeleteFailedAttempts checks that DeleteFailedAttempts properly removes -// failed HTLCs from finished payments. -func TestDeleteFailedAttempts(t *testing.T) { - t.Parallel() - - t.Run("keep failed payment attempts", func(t *testing.T) { - testDeleteFailedAttempts(t, true) - }) - t.Run("remove failed payment attempts", func(t *testing.T) { - testDeleteFailedAttempts(t, false) - }) -} - -func testDeleteFailedAttempts(t *testing.T, keepFailedPaymentAttempts bool) { - db, err := MakeTestDB(t) - require.NoError(t, err, "unable to init db") - - paymentDB, err := NewKVPaymentsDB( - db, - paymentsdb.WithKeepFailedPaymentAttempts( - keepFailedPaymentAttempts, - ), - ) - require.NoError(t, err) - - // Register three payments: - // All payments will have one failed HTLC attempt and one HTLC attempt - // according to its final status. - // 1. A payment with two failed attempts. - // 2. A payment with one failed and one in-flight attempt. - // 3. A payment with one failed and one settled attempt. - - // Initiate payments, which is a slice of payment that is used as - // template to create the corresponding test payments in the database. - // - // Note: The payment id and number of htlc attempts of each payment will - // be added to this slice when creating the payments below. - // This allows the slice to be used directly for testing purposes. - payments := []*payment{ - {status: StatusFailed}, - {status: StatusInFlight}, - {status: StatusSucceeded}, - } - - // Use helper function to register the test payments in the data and - // populate the data to the payments slice. - createTestPayments(t, paymentDB, payments) - - // Check that all payments are there as we added them. - assertPayments(t, paymentDB, payments) - - // Calling DeleteFailedAttempts on a failed payment should delete all - // HTLCs. - require.NoError(t, paymentDB.DeleteFailedAttempts(payments[0].id)) - - // Expect all HTLCs to be deleted if the config is set to delete them. - if !keepFailedPaymentAttempts { - payments[0].htlcs = 0 - } - assertPayments(t, paymentDB, payments) - - // Calling DeleteFailedAttempts on an in-flight payment should return - // an error. - if keepFailedPaymentAttempts { - require.NoError( - t, paymentDB.DeleteFailedAttempts(payments[1].id), - ) - } else { - require.Error(t, paymentDB.DeleteFailedAttempts(payments[1].id)) - } - - // Since DeleteFailedAttempts returned an error, we should expect the - // payment to be unchanged. - assertPayments(t, paymentDB, payments) - - // Cleaning up a successful payment should remove failed htlcs. - require.NoError(t, paymentDB.DeleteFailedAttempts(payments[2].id)) - // Expect all HTLCs except for the settled one to be deleted if the - // config is set to delete them. - if !keepFailedPaymentAttempts { - payments[2].htlcs = 1 - } - assertPayments(t, paymentDB, payments) + testCases := []struct { + status PaymentStatus + remainingAmt lnwire.MilliSatoshi + hasSettledHTLC bool + hasFailureReason bool + needWait bool + expectedErr error + }{ + { + // For a newly created payment we don't need to wait + // for results. + status: StatusInitiated, + remainingAmt: 1000, + needWait: false, + expectedErr: nil, + }, + { + // With HTLCs inflight we don't need to wait when the + // remainingAmt is not zero and we have no settled + // HTLCs. + status: StatusInFlight, + remainingAmt: 1000, + needWait: false, + expectedErr: nil, + }, + { + // With HTLCs inflight we need to wait when the + // remainingAmt is not zero but we have settled HTLCs. + status: StatusInFlight, + remainingAmt: 1000, + hasSettledHTLC: true, + needWait: true, + expectedErr: nil, + }, + { + // With HTLCs inflight we need to wait when the + // remainingAmt is not zero and the payment is failed. + status: StatusInFlight, + remainingAmt: 1000, + needWait: true, + hasFailureReason: true, + expectedErr: nil, + }, - if keepFailedPaymentAttempts { - // DeleteFailedAttempts is ignored, even for non-existent - // payments, if the control tower is configured to keep failed - // HTLCs. - require.NoError( - t, paymentDB.DeleteFailedAttempts(lntypes.ZeroHash), - ) - } else { - // Attempting to cleanup a non-existent payment returns an error. - require.Error( - t, paymentDB.DeleteFailedAttempts(lntypes.ZeroHash), - ) + { + // With the payment settled, but the remainingAmt is + // not zero, we have an error state. + status: StatusSucceeded, + remainingAmt: 1000, + needWait: false, + expectedErr: ErrPaymentInternal, + }, + { + // Payment is in terminal state, no need to wait. + status: StatusFailed, + remainingAmt: 1000, + needWait: false, + expectedErr: nil, + }, + { + // A newly created payment with zero remainingAmt + // indicates an error. + status: StatusInitiated, + remainingAmt: 0, + needWait: false, + expectedErr: ErrPaymentInternal, + }, + { + // With zero remainingAmt we must wait for the results. + status: StatusInFlight, + remainingAmt: 0, + needWait: true, + expectedErr: nil, + }, + { + // Payment is terminated, no need to wait for results. + status: StatusSucceeded, + remainingAmt: 0, + needWait: false, + expectedErr: nil, + }, + { + // Payment is terminated, no need to wait for results. + status: StatusFailed, + remainingAmt: 0, + needWait: false, + expectedErr: ErrPaymentInternal, + }, + { + // Payment is in an unknown status, return an error. + status: 0, + remainingAmt: 0, + needWait: false, + expectedErr: ErrUnknownPaymentStatus, + }, + { + // Payment is in an unknown status, return an error. + status: 0, + remainingAmt: 1000, + needWait: false, + expectedErr: ErrUnknownPaymentStatus, + }, } -} -// assertPaymentStatus retrieves the status of the payment referred to by hash -// and compares it with the expected state. -func assertPaymentStatus(t *testing.T, p *KVPaymentsDB, - hash lntypes.Hash, expStatus PaymentStatus) { + for _, tc := range testCases { + tc := tc - t.Helper() + p := &MPPayment{ + Info: &PaymentCreationInfo{ + PaymentIdentifier: [32]byte{1, 2, 3}, + }, + Status: tc.status, + State: &MPPaymentState{ + RemainingAmt: tc.remainingAmt, + HasSettledHTLC: tc.hasSettledHTLC, + PaymentFailed: tc.hasFailureReason, + }, + } - payment, err := p.FetchPayment(hash) - if errors.Is(err, paymentsdb.ErrPaymentNotInitiated) { - return - } - if err != nil { - t.Fatal(err) - } + name := fmt.Sprintf("status=%s|remainingAmt=%v|"+ + "settledHTLC=%v|failureReason=%v", tc.status, + tc.remainingAmt, tc.hasSettledHTLC, tc.hasFailureReason) - if payment.Status != expStatus { - t.Fatalf("payment status mismatch: expected %v, got %v", - expStatus, payment.Status) - } -} + t.Run(name, func(t *testing.T) { + t.Parallel() -type htlcStatus struct { - *HTLCAttemptInfo - settle *lntypes.Preimage - failure *HTLCFailReason + result, err := p.NeedWaitAttempts() + require.ErrorIs(t, err, tc.expectedErr) + require.Equalf(t, tc.needWait, result, "status=%v, "+ + "remainingAmt=%v", tc.status, tc.remainingAmt) + }) + } } -// assertPaymentInfo retrieves the payment referred to by hash and verifies the -// expected values. -func assertPaymentInfo(t *testing.T, p *KVPaymentsDB, hash lntypes.Hash, - c *PaymentCreationInfo, f *FailureReason, a *htlcStatus) { - - t.Helper() +// TestAllowMoreAttempts checks whether more attempts can be created against +// ALL possible payment statuses. +func TestAllowMoreAttempts(t *testing.T) { + t.Parallel() - payment, err := p.FetchPayment(hash) - if err != nil { - t.Fatal(err) + testCases := []struct { + status PaymentStatus + remainingAmt lnwire.MilliSatoshi + hasSettledHTLC bool + paymentFailed bool + allowMore bool + expectedErr error + }{ + { + // A newly created payment with zero remainingAmt + // indicates an error. + status: StatusInitiated, + remainingAmt: 0, + allowMore: false, + expectedErr: ErrPaymentInternal, + }, + { + // With zero remainingAmt we don't allow more HTLC + // attempts. + status: StatusInFlight, + remainingAmt: 0, + allowMore: false, + expectedErr: nil, + }, + { + // With zero remainingAmt we don't allow more HTLC + // attempts. + status: StatusSucceeded, + remainingAmt: 0, + allowMore: false, + expectedErr: nil, + }, + { + // With zero remainingAmt we don't allow more HTLC + // attempts. + status: StatusFailed, + remainingAmt: 0, + allowMore: false, + expectedErr: nil, + }, + { + // With zero remainingAmt and settled HTLCs we don't + // allow more HTLC attempts. + status: StatusInFlight, + remainingAmt: 0, + hasSettledHTLC: true, + allowMore: false, + expectedErr: nil, + }, + { + // With zero remainingAmt and failed payment we don't + // allow more HTLC attempts. + status: StatusInFlight, + remainingAmt: 0, + paymentFailed: true, + allowMore: false, + expectedErr: nil, + }, + { + // With zero remainingAmt and both settled HTLCs and + // failed payment, we don't allow more HTLC attempts. + status: StatusInFlight, + remainingAmt: 0, + hasSettledHTLC: true, + paymentFailed: true, + allowMore: false, + expectedErr: nil, + }, + { + // A newly created payment can have more attempts. + status: StatusInitiated, + remainingAmt: 1000, + allowMore: true, + expectedErr: nil, + }, + { + // With HTLCs inflight we can have more attempts when + // the remainingAmt is not zero and we have neither + // failed payment or settled HTLCs. + status: StatusInFlight, + remainingAmt: 1000, + allowMore: true, + expectedErr: nil, + }, + { + // With HTLCs inflight we cannot have more attempts + // though the remainingAmt is not zero but we have + // settled HTLCs. + status: StatusInFlight, + remainingAmt: 1000, + hasSettledHTLC: true, + allowMore: false, + expectedErr: nil, + }, + { + // With HTLCs inflight we cannot have more attempts + // though the remainingAmt is not zero but we have + // failed payment. + status: StatusInFlight, + remainingAmt: 1000, + paymentFailed: true, + allowMore: false, + expectedErr: nil, + }, + { + // With HTLCs inflight we cannot have more attempts + // though the remainingAmt is not zero but we have + // settled HTLCs and failed payment. + status: StatusInFlight, + remainingAmt: 1000, + hasSettledHTLC: true, + paymentFailed: true, + allowMore: false, + expectedErr: nil, + }, + { + // With the payment settled, but the remainingAmt is + // not zero, we have an error state. + status: StatusSucceeded, + remainingAmt: 1000, + hasSettledHTLC: true, + allowMore: false, + expectedErr: ErrPaymentInternal, + }, + { + // With the payment failed with no inflight HTLCs, we + // don't allow more attempts to be made. + status: StatusFailed, + remainingAmt: 1000, + paymentFailed: true, + allowMore: false, + expectedErr: nil, + }, + { + // With the payment in an unknown state, we don't allow + // more attempts to be made. + status: 0, + remainingAmt: 1000, + allowMore: false, + expectedErr: nil, + }, } - if !reflect.DeepEqual(payment.Info, c) { - t.Fatalf("PaymentCreationInfos don't match: %v vs %v", - spew.Sdump(payment.Info), spew.Sdump(c)) - } + for i, tc := range testCases { + tc := tc - if f != nil { - if *payment.FailureReason != *f { - t.Fatal("unexpected failure reason") - } - } else { - if payment.FailureReason != nil { - t.Fatal("unexpected failure reason") + p := &MPPayment{ + Info: &PaymentCreationInfo{ + PaymentIdentifier: [32]byte{1, 2, 3}, + }, + Status: tc.status, + State: &MPPaymentState{ + RemainingAmt: tc.remainingAmt, + HasSettledHTLC: tc.hasSettledHTLC, + PaymentFailed: tc.paymentFailed, + }, } - } - if a == nil { - if len(payment.HTLCs) > 0 { - t.Fatal("expected no htlcs") - } - return - } + name := fmt.Sprintf("test_%d|status=%s|remainingAmt=%v", i, + tc.status, tc.remainingAmt) - htlc := payment.HTLCs[a.AttemptID] - if err := assertRouteEqual(&htlc.Route, &a.Route); err != nil { - t.Fatal("routes do not match") + t.Run(name, func(t *testing.T) { + t.Parallel() + + result, err := p.AllowMoreAttempts() + require.ErrorIs(t, err, tc.expectedErr) + require.Equalf(t, tc.allowMore, result, "status=%v, "+ + "remainingAmt=%v", tc.status, tc.remainingAmt) + }) } +} - if htlc.AttemptID != a.AttemptID { - t.Fatalf("unnexpected attempt ID %v, expected %v", - htlc.AttemptID, a.AttemptID) +func makeActiveAttempt(total, fee int) HTLCAttempt { + return HTLCAttempt{ + HTLCAttemptInfo: makeAttemptInfo(total, total-fee), } +} - if a.failure != nil { - if htlc.Failure == nil { - t.Fatalf("expected HTLC to be failed") - } +func makeSettledAttempt(total, fee int, + preimage lntypes.Preimage) HTLCAttempt { - if htlc.Failure.Reason != *a.failure { - t.Fatalf("expected HTLC failure %v, had %v", - *a.failure, htlc.Failure.Reason) - } - } else if htlc.Failure != nil { - t.Fatalf("expected no HTLC failure") + return HTLCAttempt{ + HTLCAttemptInfo: makeAttemptInfo(total, total-fee), + Settle: &HTLCSettleInfo{Preimage: preimage}, } +} - if a.settle != nil { - if htlc.Settle.Preimage != *a.settle { - t.Fatalf("Preimages don't match: %x vs %x", - htlc.Settle.Preimage, a.settle) - } - } else if htlc.Settle != nil { - t.Fatal("expected no settle info") +func makeFailedAttempt(total, fee int) HTLCAttempt { + return HTLCAttempt{ + HTLCAttemptInfo: makeAttemptInfo(total, total-fee), + Failure: &HTLCFailInfo{ + Reason: HTLCFailInternal, + }, } } -// fetchPaymentIndexEntry gets the payment hash for the sequence number provided -// from our payment indexes bucket. -func fetchPaymentIndexEntry(_ *testing.T, p *KVPaymentsDB, - sequenceNumber uint64) (*lntypes.Hash, error) { +func makeAttemptInfo(total, amtForwarded int) HTLCAttemptInfo { + hop := &route.Hop{AmtToForward: lnwire.MilliSatoshi(amtForwarded)} + return HTLCAttemptInfo{ + Route: route.Route{ + TotalAmount: lnwire.MilliSatoshi(total), + Hops: []*route.Hop{hop}, + }, + } +} - var hash lntypes.Hash +// TestEmptyRoutesGenerateSphinxPacket tests that the generateSphinxPacket +// function is able to gracefully handle being passed a nil set of hops for the +// route by the caller. +func TestEmptyRoutesGenerateSphinxPacket(t *testing.T) { + t.Parallel() - if err := kvdb.View(p.db, func(tx walletdb.ReadTx) error { - indexBucket := tx.ReadBucket(paymentsIndexBucket) - key := make([]byte, 8) - byteOrder.PutUint64(key, sequenceNumber) + sessionKey, _ := btcec.NewPrivateKey() + emptyRoute := &route.Route{} + _, _, err := generateSphinxPacket(emptyRoute, testHash[:], sessionKey) + require.ErrorIs(t, err, route.ErrNoRouteHopsProvided) +} - indexValue := indexBucket.Get(key) - if indexValue == nil { - return paymentsdb.ErrNoSequenceNrIndex - } +// TestSuccessesWithoutInFlight tests that the payment control will disallow +// calls to Success when no payment is in flight. +func TestSuccessesWithoutInFlight(t *testing.T) { + t.Parallel() - r := bytes.NewReader(indexValue) + paymentDB := NewTestDB(t) - var err error - hash, err = deserializePaymentIndex(r) - return err - }, func() { - hash = lntypes.Hash{} - }); err != nil { - return nil, err - } + info, _, preimg, err := genInfo(t) + require.NoError(t, err, "unable to generate htlc message") - return &hash, nil + // Attempt to complete the payment should fail. + _, err = paymentDB.SettleAttempt( + info.PaymentIdentifier, 0, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + require.ErrorIs(t, err, ErrPaymentNotInitiated) } -// assertPaymentIndex looks up the index for a payment in the db and checks -// that its payment hash matches the expected hash passed in. -func assertPaymentIndex(t *testing.T, p *KVPaymentsDB, - expectedHash lntypes.Hash) { - - // Lookup the payment so that we have its sequence number and check - // that is has correctly been indexed in the payment indexes bucket. - pmt, err := p.FetchPayment(expectedHash) - require.NoError(t, err) +// TestFailsWithoutInFlight checks that a strict payment control will disallow +// calls to Fail when no payment is in flight. +func TestFailsWithoutInFlight(t *testing.T) { + t.Parallel() - hash, err := fetchPaymentIndexEntry(t, p, pmt.SequenceNum) - require.NoError(t, err) - assert.Equal(t, expectedHash, *hash) -} + paymentDB := NewTestDB(t) -// assertNoIndex checks that an index for the sequence number provided does not -// exist. -func assertNoIndex(t *testing.T, p *KVPaymentsDB, seqNr uint64) { - _, err := fetchPaymentIndexEntry(t, p, seqNr) - require.Equal(t, paymentsdb.ErrNoSequenceNrIndex, err) -} + info, _, _, err := genInfo(t) + require.NoError(t, err, "unable to generate htlc message") -// payment is a helper structure that holds basic information on a test payment, -// such as the payment id, the status and the total number of HTLCs attempted. -type payment struct { - id lntypes.Hash - status PaymentStatus - htlcs int + // Calling Fail should return an error. + _, err = paymentDB.Fail( + info.PaymentIdentifier, FailureReasonNoRoute, + ) + require.ErrorIs(t, err, ErrPaymentNotInitiated) } -// createTestPayments registers payments depending on the provided statuses in -// the payments slice. Each payment will receive one failed HTLC and another -// HTLC depending on the final status of the payment provided. -func createTestPayments(t *testing.T, p *KVPaymentsDB, payments []*payment) { - attemptID := uint64(0) - - for i := 0; i < len(payments); i++ { - info, attempt, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") - - // Set the payment id accordingly in the payments slice. - payments[i].id = info.PaymentIdentifier +// TestDeletePayments tests that DeletePayments correctly deletes information +// about completed payments from the database. +func TestDeletePayments(t *testing.T) { + t.Parallel() - attempt.AttemptID = attemptID - attemptID++ + paymentDB := NewTestDB(t) - // Init the payment. - err = p.InitPayment(info.PaymentIdentifier, info) - require.NoError(t, err, "unable to send htlc message") + // Register three payments: + // 1. A payment with two failed attempts. + // 2. A payment with one failed and one settled attempt. + // 3. A payment with one failed and one in-flight attempt. + payments := []*payment{ + {status: StatusFailed}, + {status: StatusSucceeded}, + {status: StatusInFlight}, + } - // Register and fail the first attempt for all payments. - _, err = p.RegisterAttempt(info.PaymentIdentifier, attempt) - require.NoError(t, err, "unable to send htlc message") + // Use helper function to register the test payments in the data and + // populate the data to the payments slice. + createTestPayments(t, paymentDB, payments) - htlcFailure := HTLCFailUnreadable - _, err = p.FailAttempt( - info.PaymentIdentifier, attempt.AttemptID, - &HTLCFailInfo{ - Reason: htlcFailure, - }, - ) - require.NoError(t, err, "unable to fail htlc") + // Check that all payments are there as we added them. + assertPayments(t, paymentDB, payments) - // Increase the HTLC counter in the payments slice for the - // failed attempt. - payments[i].htlcs++ + // Delete HTLC attempts for failed payments only. + numPayments, err := paymentDB.DeletePayments(true, true) + require.NoError(t, err) + require.EqualValues(t, 0, numPayments) - // Depending on the test case, fail or succeed the next - // attempt. - attempt.AttemptID = attemptID - attemptID++ + // The failed payment is the only altered one. + payments[0].htlcs = 0 + assertPayments(t, paymentDB, payments) - _, err = p.RegisterAttempt(info.PaymentIdentifier, attempt) - require.NoError(t, err, "unable to send htlc message") + // Delete failed attempts for all payments. + numPayments, err = paymentDB.DeletePayments(false, true) + require.NoError(t, err) + require.EqualValues(t, 0, numPayments) - switch payments[i].status { - // Fail the attempt and the payment overall. - case StatusFailed: - htlcFailure := HTLCFailUnreadable - _, err = p.FailAttempt( - info.PaymentIdentifier, attempt.AttemptID, - &HTLCFailInfo{ - Reason: htlcFailure, - }, - ) - require.NoError(t, err, "unable to fail htlc") + // The failed attempts should be deleted, except for the in-flight + // payment, that shouldn't be altered until it has completed. + payments[1].htlcs = 1 + assertPayments(t, paymentDB, payments) - failReason := FailureReasonNoRoute - _, err = p.Fail(info.PaymentIdentifier, - failReason) - require.NoError(t, err, "unable to fail payment hash") + // Now delete all failed payments. + numPayments, err = paymentDB.DeletePayments(true, false) + require.NoError(t, err) + require.EqualValues(t, 1, numPayments) - // Settle the attempt - case StatusSucceeded: - _, err := p.SettleAttempt( - info.PaymentIdentifier, attempt.AttemptID, - &HTLCSettleInfo{ - Preimage: preimg, - }, - ) - require.NoError(t, err, "no error should have been "+ - "received from settling a htlc attempt") + assertPayments(t, paymentDB, payments[1:]) - // We leave the attempt in-flight by doing nothing. - case StatusInFlight: - } + // Finally delete all completed payments. + numPayments, err = paymentDB.DeletePayments(false, false) + require.NoError(t, err) + require.EqualValues(t, 1, numPayments) - // Increase the HTLC counter in the payments slice for any - // attempt above. - payments[i].htlcs++ - } + assertPayments(t, paymentDB, payments[2:]) } -// assertPayments is a helper function that given a slice of payment and -// indices for the slice asserts that exactly the same payments in the -// slice for the provided indices exist when fetching payments from the -// database. -func assertPayments(t *testing.T, paymentDB *KVPaymentsDB, - payments []*payment) { +// TestSwitchDoubleSend checks the ability of payment control to +// prevent double sending of htlc message, when message is in StatusInFlight. +func TestSwitchDoubleSend(t *testing.T) { + t.Parallel() - t.Helper() + paymentDB := NewTestDB(t) - dbPayments, err := paymentDB.FetchPayments() - require.NoError(t, err, "could not fetch payments from db") + info, attempt, preimg, err := genInfo(t) + require.NoError(t, err, "unable to generate htlc message") - // Make sure that the number of fetched payments is the same - // as expected. - require.Len( - t, dbPayments, len(payments), "unexpected number of payments", - ) + // Sends base htlc message which initiate base status and move it to + // StatusInFlight and verifies that it was changed. + err = paymentDB.InitPayment(info.PaymentIdentifier, info) + require.NoError(t, err, "unable to send htlc message") - // Convert fetched payments of type MPPayment to our helper structure. - p := make([]*payment, len(dbPayments)) - for i, dbPayment := range dbPayments { - p[i] = &payment{ - id: dbPayment.Info.PaymentIdentifier, - status: dbPayment.Status, - htlcs: len(dbPayment.HTLCs), - } - } + assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, StatusInitiated, + ) + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, nil, + ) - // Check that each payment we want to assert exists in the database. - require.Equal(t, payments, p) -} + // Try to initiate double sending of htlc message with the same + // payment hash, should result in error indicating that payment has + // already been sent. + err = paymentDB.InitPayment(info.PaymentIdentifier, info) + require.ErrorIs(t, err, ErrPaymentExists) -func makeFakeInfo(t *testing.T) (*PaymentCreationInfo, *HTLCAttemptInfo) { - var preimg lntypes.Preimage - copy(preimg[:], rev[:]) + // Record an attempt. + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) + require.NoError(t, err, "unable to send htlc message") + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, StatusInFlight, + ) - hash := preimg.Hash() + htlc := &htlcStatus{ + HTLCAttemptInfo: attempt, + } + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, htlc, + ) - c := &PaymentCreationInfo{ - PaymentIdentifier: hash, - Value: 1000, - // Use single second precision to avoid false positive test - // failures due to the monotonic time component. - CreationTime: time.Unix(time.Now().Unix(), 0), - PaymentRequest: []byte("test"), + // Sends base htlc message which initiate StatusInFlight. + err = paymentDB.InitPayment(info.PaymentIdentifier, info) + if !errors.Is(err, ErrPaymentInFlight) { + t.Fatalf("payment control wrong behaviour: " + + "double sending must trigger ErrPaymentInFlight error") } - a, err := NewHtlcAttempt( - 44, priv, testRoute, time.Unix(100, 0), &hash, + // After settling, the error should be ErrAlreadyPaid. + _, err = paymentDB.SettleAttempt( + info.PaymentIdentifier, attempt.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + require.NoError(t, err, "error shouldn't have been received, got") + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, StatusSucceeded, + ) + + htlc.settle = &preimg + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, htlc, ) - require.NoError(t, err) - return c, &a.HTLCAttemptInfo + err = paymentDB.InitPayment(info.PaymentIdentifier, info) + if !errors.Is(err, ErrAlreadyPaid) { + t.Fatalf("unable to send htlc message: %v", err) + } } -func TestSentPaymentSerialization(t *testing.T) { +// TestSwitchFail checks that payment status returns to Failed status after +// failing, and that InitPayment allows another HTLC for the same payment hash. +func TestSwitchFail(t *testing.T) { t.Parallel() - c, s := makeFakeInfo(t) + paymentDB := NewTestDB(t) + + info, attempt, preimg, err := genInfo(t) + require.NoError(t, err, "unable to generate htlc message") + + // Sends base htlc message which initiate StatusInFlight. + err = paymentDB.InitPayment(info.PaymentIdentifier, info) + require.NoError(t, err, "unable to send htlc message") + + assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, StatusInitiated, + ) + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, nil, + ) + + // Fail the payment, which should moved it to Failed. + failReason := FailureReasonNoRoute + _, err = paymentDB.Fail(info.PaymentIdentifier, failReason) + require.NoError(t, err, "unable to fail payment hash") - var b bytes.Buffer - require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize") + // Verify the status is indeed Failed. + assertPaymentStatus(t, paymentDB, info.PaymentIdentifier, StatusFailed) + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, &failReason, nil, + ) - // Assert the length of the serialized creation info is as expected, - // without any custom records. - baseLength := 32 + 8 + 8 + 4 + len(c.PaymentRequest) - require.Len(t, b.Bytes(), baseLength) + // Lookup the payment so we can get its old sequence number before it is + // overwritten. + payment, err := paymentDB.FetchPayment(info.PaymentIdentifier) + require.NoError(t, err) - newCreationInfo, err := deserializePaymentCreationInfo(&b) - require.NoError(t, err, "deserialize") - require.Equal(t, c, newCreationInfo) + // Sends the htlc again, which should succeed since the prior payment + // failed. + err = paymentDB.InitPayment(info.PaymentIdentifier, info) + require.NoError(t, err, "unable to send htlc message") - b.Reset() + // Check that our index has been updated, and the old index has been + // removed. + assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) + assertNoIndex(t, paymentDB, payment.SequenceNum) - // Now we add some custom records to the creation info and serialize it - // again. - c.FirstHopCustomRecords = lnwire.CustomRecords{ - lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3}, - } - require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize") + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, StatusInitiated, + ) + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, nil, + ) - newCreationInfo, err = deserializePaymentCreationInfo(&b) - require.NoError(t, err, "deserialize") - require.Equal(t, c, newCreationInfo) + // Record a new attempt. In this test scenario, the attempt fails. + // However, this is not communicated to control tower in the current + // implementation. It only registers the initiation of the attempt. + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) + require.NoError(t, err, "unable to register attempt") - b.Reset() - require.NoError(t, serializeHTLCAttemptInfo(&b, s), "serialize") + htlcReason := HTLCFailUnreadable + _, err = paymentDB.FailAttempt( + info.PaymentIdentifier, attempt.AttemptID, + &HTLCFailInfo{ + Reason: htlcReason, + }, + ) + if err != nil { + t.Fatal(err) + } + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, StatusInFlight, + ) - newWireInfo, err := deserializeHTLCAttemptInfo(&b) - require.NoError(t, err, "deserialize") + htlc := &htlcStatus{ + HTLCAttemptInfo: attempt, + failure: &htlcReason, + } - // First we verify all the records match up properly. - require.Equal(t, s.Route, newWireInfo.Route) + assertPaymentInfo(t, paymentDB, info.PaymentIdentifier, info, nil, htlc) - // We now add the new fields and custom records to the route and - // serialize it again. - b.Reset() - s.Route.FirstHopAmount = tlv.NewRecordT[tlv.TlvType0]( - tlv.NewBigSizeT(lnwire.MilliSatoshi(1234)), + // Record another attempt. + attempt.AttemptID = 1 + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) + require.NoError(t, err, "unable to send htlc message") + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, StatusInFlight, ) - s.Route.FirstHopWireCustomRecords = lnwire.CustomRecords{ - lnwire.MinCustomRecordsTlvType + 3: []byte{4, 5, 6}, + + htlc = &htlcStatus{ + HTLCAttemptInfo: attempt, } - require.NoError(t, serializeHTLCAttemptInfo(&b, s), "serialize") - newWireInfo, err = deserializeHTLCAttemptInfo(&b) - require.NoError(t, err, "deserialize") - require.Equal(t, s.Route, newWireInfo.Route) + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, htlc, + ) - err = newWireInfo.attachOnionBlobAndCircuit() - require.NoError(t, err) + // Settle the attempt and verify that status was changed to + // StatusSucceeded. + payment, err = paymentDB.SettleAttempt( + info.PaymentIdentifier, attempt.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + require.NoError(t, err, "error shouldn't have been received, got") + + if len(payment.HTLCs) != 2 { + t.Fatalf("payment should have two htlcs, got: %d", + len(payment.HTLCs)) + } + + err = assertRouteEqual(&payment.HTLCs[0].Route, &attempt.Route) + if err != nil { + t.Fatalf("unexpected route returned: %v vs %v: %v", + spew.Sdump(attempt.Route), + spew.Sdump(payment.HTLCs[0].Route), err) + } - // Clear routes to allow DeepEqual to compare the remaining fields. - newWireInfo.Route = route.Route{} - s.Route = route.Route{} - newWireInfo.AttemptID = s.AttemptID + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, StatusSucceeded, + ) - // Call session key method to set our cached session key so we can use - // DeepEqual, and assert that our key equals the original key. - require.Equal(t, s.cachedSessionKey, newWireInfo.SessionKey()) + htlc.settle = &preimg + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, htlc, + ) - require.Equal(t, s, newWireInfo) + // Attempt a final payment, which should now fail since the prior + // payment succeed. + err = paymentDB.InitPayment(info.PaymentIdentifier, info) + if !errors.Is(err, ErrAlreadyPaid) { + t.Fatalf("unable to send htlc message: %v", err) + } } -// TestRouteSerialization tests serialization of a regular and blinded route. -func TestRouteSerialization(t *testing.T) { +// TestMultiShard checks the ability of payment control to have multiple in- +// flight HTLCs for a single payment. +func TestMultiShard(t *testing.T) { t.Parallel() - testSerializeRoute(t, testRoute) - testSerializeRoute(t, testBlindedRoute) -} + // We will register three HTLC attempts, and always fail the second + // one. We'll generate all combinations of settling/failing the first + // and third HTLC, and assert that the payment status end up as we + // expect. + type testCase struct { + settleFirst bool + settleLast bool + } -func testSerializeRoute(t *testing.T, route route.Route) { - var b bytes.Buffer - err := SerializeRoute(&b, route) - require.NoError(t, err) + var tests []testCase + for _, f := range []bool{true, false} { + for _, l := range []bool{true, false} { + tests = append(tests, testCase{f, l}) + } + } - r := bytes.NewReader(b.Bytes()) - route2, err := DeserializeRoute(r) - require.NoError(t, err) + runSubTest := func(t *testing.T, test testCase) { + paymentDB := NewTestDB(t) - reflect.DeepEqual(route, route2) -} + info, attempt, preimg, err := genInfo(t) + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Init the payment, moving it to the StatusInFlight state. + err = paymentDB.InitPayment(info.PaymentIdentifier, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, StatusInitiated, + ) + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, nil, + ) + + // Create three unique attempts we'll use for the test, and + // register them with the payment control. We set each + // attempts's value to one third of the payment amount, and + // populate the MPP options. + shardAmt := info.Value / 3 + attempt.Route.FinalHop().AmtToForward = shardAmt + attempt.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + + var attempts []*HTLCAttemptInfo + for i := uint64(0); i < 3; i++ { + a := *attempt + a.AttemptID = i + attempts = append(attempts, &a) -// deletePayment removes a payment with paymentHash from the payments database. -func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, - seqNr uint64) { + _, err = paymentDB.RegisterAttempt( + info.PaymentIdentifier, &a, + ) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, + StatusInFlight, + ) - t.Helper() + htlc := &htlcStatus{ + HTLCAttemptInfo: &a, + } + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, + htlc, + ) + } - err := kvdb.Update(db, func(tx kvdb.RwTx) error { - payments := tx.ReadWriteBucket(paymentsRootBucket) + // For a fourth attempt, check that attempting to + // register it will fail since the total sent amount + // will be too large. + b := *attempt + b.AttemptID = 3 + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + require.ErrorIs(t, err, ErrValueExceedsAmt) - // Delete the payment bucket. - err := payments.DeleteNestedBucket(paymentHash[:]) + // Fail the second attempt. + a := attempts[1] + htlcFail := HTLCFailUnreadable + _, err = paymentDB.FailAttempt( + info.PaymentIdentifier, a.AttemptID, + &HTLCFailInfo{ + Reason: htlcFail, + }, + ) if err != nil { - return err + t.Fatal(err) } - key := make([]byte, 8) - byteOrder.PutUint64(key, seqNr) - - // Delete the index that references this payment. - indexes := tx.ReadWriteBucket(paymentsIndexBucket) - - return indexes.Delete(key) - }, func() {}) - - if err != nil { - t.Fatalf("could not delete "+ - "payment: %v", err) - } -} - -// TestFetchPaymentWithSequenceNumber tests lookup of payments with their -// sequence number. It sets up one payment with no duplicates, and another with -// two duplicates in its duplicates bucket then uses these payments to test the -// case where a specific duplicate is not found and the duplicates bucket is not -// present when we expect it to be. -func TestFetchPaymentWithSequenceNumber(t *testing.T) { - db, err := MakeTestDB(t) - require.NoError(t, err) - - paymentDB, err := NewKVPaymentsDB(db) - require.NoError(t, err) - - // Generate a test payment which does not have duplicates. - noDuplicates, _, _, err := genInfo(t) - require.NoError(t, err) - - // Create a new payment entry in the database. - err = paymentDB.InitPayment( - noDuplicates.PaymentIdentifier, noDuplicates, - ) - require.NoError(t, err) + htlc := &htlcStatus{ + HTLCAttemptInfo: a, + failure: &htlcFail, + } + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, htlc, + ) - // Fetch the payment so we can get its sequence nr. - noDuplicatesPayment, err := paymentDB.FetchPayment( - noDuplicates.PaymentIdentifier, - ) - require.NoError(t, err) + // Payment should still be in-flight. + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, StatusInFlight, + ) - // Generate a test payment which we will add duplicates to. - hasDuplicates, _, preimg, err := genInfo(t) - require.NoError(t, err) + // Depending on the test case, settle or fail the first attempt. + a = attempts[0] + htlc = &htlcStatus{ + HTLCAttemptInfo: a, + } - // Create a new payment entry in the database. - err = paymentDB.InitPayment( - hasDuplicates.PaymentIdentifier, hasDuplicates, - ) - require.NoError(t, err) + var firstFailReason *FailureReason + if test.settleFirst { + _, err := paymentDB.SettleAttempt( + info.PaymentIdentifier, a.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + if err != nil { + t.Fatalf("error shouldn't have been "+ + "received, got: %v", err) + } - // Fetch the payment so we can get its sequence nr. - hasDuplicatesPayment, err := paymentDB.FetchPayment( - hasDuplicates.PaymentIdentifier, - ) - require.NoError(t, err) + // Assert that the HTLC has had the preimage recorded. + htlc.settle = &preimg + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, + htlc, + ) + } else { + _, err := paymentDB.FailAttempt( + info.PaymentIdentifier, a.AttemptID, + &HTLCFailInfo{ + Reason: htlcFail, + }, + ) + if err != nil { + t.Fatalf("error shouldn't have been "+ + "received, got: %v", err) + } - // We declare the sequence numbers used here so that we can reference - // them in tests. - var ( - duplicateOneSeqNr = hasDuplicatesPayment.SequenceNum + 1 - duplicateTwoSeqNr = hasDuplicatesPayment.SequenceNum + 2 - ) + // Assert the failure was recorded. + htlc.failure = &htlcFail + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, nil, + htlc, + ) - // Add two duplicates to our second payment. - appendDuplicatePayment( - t, db, hasDuplicates.PaymentIdentifier, duplicateOneSeqNr, - preimg, - ) - appendDuplicatePayment( - t, db, hasDuplicates.PaymentIdentifier, duplicateTwoSeqNr, - preimg, - ) + // We also record a payment level fail, to move it into + // a terminal state. + failReason := FailureReasonNoRoute + _, err = paymentDB.Fail( + info.PaymentIdentifier, failReason, + ) + if err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } - tests := []struct { - name string - paymentHash lntypes.Hash - sequenceNumber uint64 - expectedErr error - }{ - { - name: "lookup payment without duplicates", - paymentHash: noDuplicates.PaymentIdentifier, - sequenceNumber: noDuplicatesPayment.SequenceNum, - expectedErr: nil, - }, - { - name: "lookup payment with duplicates", - paymentHash: hasDuplicates.PaymentIdentifier, - sequenceNumber: hasDuplicatesPayment.SequenceNum, - expectedErr: nil, - }, - { - name: "lookup first duplicate", - paymentHash: hasDuplicates.PaymentIdentifier, - sequenceNumber: duplicateOneSeqNr, - expectedErr: nil, - }, - { - name: "lookup second duplicate", - paymentHash: hasDuplicates.PaymentIdentifier, - sequenceNumber: duplicateTwoSeqNr, - expectedErr: nil, - }, - { - name: "lookup non-existent duplicate", - paymentHash: hasDuplicates.PaymentIdentifier, - sequenceNumber: 999999, - expectedErr: paymentsdb.ErrDuplicateNotFound, - }, - { - name: "lookup duplicate, no duplicates " + - "bucket", - paymentHash: noDuplicates.PaymentIdentifier, - sequenceNumber: duplicateTwoSeqNr, - expectedErr: paymentsdb.ErrNoDuplicateBucket, - }, - } + // Record the reason we failed the payment, such that + // we can assert this later in the test. + firstFailReason = &failReason - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - err := kvdb.Update( - db, func(tx walletdb.ReadWriteTx) error { - var seqNrBytes [8]byte - byteOrder.PutUint64( - seqNrBytes[:], - test.sequenceNumber, - ) - - //nolint:ll - _, err := fetchPaymentWithSequenceNumber( - tx, test.paymentHash, seqNrBytes[:], - ) - - return err - }, func() {}, + // The payment is now considered pending fail, since + // there is still an active HTLC. + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, + StatusInFlight, ) - require.Equal(t, test.expectedErr, err) - }) - } -} - -// appendDuplicatePayment adds a duplicate payment to an existing payment. Note -// that this function requires a unique sequence number. -// -// This code is *only* intended to replicate legacy duplicate payments in lnd, -// our current schema does not allow duplicates. -func appendDuplicatePayment(t *testing.T, db kvdb.Backend, - paymentHash lntypes.Hash, seqNr uint64, preImg lntypes.Preimage) { - - err := kvdb.Update(db, func(tx walletdb.ReadWriteTx) error { - bucket, err := fetchPaymentBucketUpdate( - tx, paymentHash, - ) - if err != nil { - return err } - // Create the duplicates bucket if it is not - // present. - dup, err := bucket.CreateBucketIfNotExists( - duplicatePaymentsBucket, - ) - if err != nil { - return err + // Try to register yet another attempt. This should fail now + // that the payment has reached a terminal condition. + b = *attempt + b.AttemptID = 3 + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + if test.settleFirst { + require.ErrorIs( + t, err, ErrPaymentPendingSettled, + ) + } else { + require.ErrorIs( + t, err, ErrPaymentPendingFailed, + ) } - var sequenceKey [8]byte - byteOrder.PutUint64(sequenceKey[:], seqNr) + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, StatusInFlight, + ) - // Create duplicate payments for the two dup - // sequence numbers we've setup. - putDuplicatePayment(t, dup, sequenceKey[:], paymentHash, preImg) + // Settle or fail the remaining attempt based on the testcase. + a = attempts[2] + htlc = &htlcStatus{ + HTLCAttemptInfo: a, + } + if test.settleLast { + // Settle the last outstanding attempt. + _, err = paymentDB.SettleAttempt( + info.PaymentIdentifier, a.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + require.NoError(t, err, "unable to settle") - // Finally, once we have created our entry we add an index for - // it. - err = createPaymentIndexEntry(tx, sequenceKey[:], paymentHash) - require.NoError(t, err) + htlc.settle = &preimg + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, + info, firstFailReason, htlc, + ) + } else { + // Fail the attempt. + _, err := paymentDB.FailAttempt( + info.PaymentIdentifier, a.AttemptID, + &HTLCFailInfo{ + Reason: htlcFail, + }, + ) + if err != nil { + t.Fatalf("error shouldn't have been "+ + "received, got: %v", err) + } - return nil - }, func() {}) - require.NoError(t, err, "could not create payment") -} + // Assert the failure was recorded. + htlc.failure = &htlcFail + assertPaymentInfo( + t, paymentDB, info.PaymentIdentifier, info, + firstFailReason, htlc, + ) -// putDuplicatePayment creates a duplicate payment in the duplicates bucket -// provided with the minimal information required for successful reading. -func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket, - sequenceKey []byte, paymentHash lntypes.Hash, - preImg lntypes.Preimage) { + // Check that we can override any perevious terminal + // failure. This is to allow multiple concurrent shard + // write a terminal failure to the database without + // syncing. + failReason := FailureReasonPaymentDetails + _, err = paymentDB.Fail( + info.PaymentIdentifier, failReason, + ) + require.NoError(t, err, "unable to fail") + } - paymentBucket, err := duplicateBucket.CreateBucketIfNotExists( - sequenceKey, - ) - require.NoError(t, err) + var ( + finalStatus PaymentStatus + registerErr error + ) - err = paymentBucket.Put(duplicatePaymentSequenceKey, sequenceKey) - require.NoError(t, err) + switch { + // If one of the attempts settled but the other failed with + // terminal error, we would still consider the payment is + // settled. + case test.settleFirst && !test.settleLast: + finalStatus = StatusSucceeded + registerErr = ErrPaymentAlreadySucceeded - // Generate fake information for the duplicate payment. - info, _, _, err := genInfo(t) - require.NoError(t, err) + case !test.settleFirst && test.settleLast: + finalStatus = StatusSucceeded + registerErr = ErrPaymentAlreadySucceeded - // Write the payment info to disk under the creation info key. This code - // is copied rather than using serializePaymentCreationInfo to ensure - // we always write in the legacy format used by duplicate payments. - var b bytes.Buffer - var scratch [8]byte - _, err = b.Write(paymentHash[:]) - require.NoError(t, err) + // If both failed, we end up in a failed status. + case !test.settleFirst && !test.settleLast: + finalStatus = StatusFailed + registerErr = ErrPaymentAlreadyFailed - byteOrder.PutUint64(scratch[:], uint64(info.Value)) - _, err = b.Write(scratch[:]) - require.NoError(t, err) + // Otherwise, the payment has a succeed status. + case test.settleFirst && test.settleLast: + finalStatus = StatusSucceeded + registerErr = ErrPaymentAlreadySucceeded + } - err = serializeTime(&b, info.CreationTime) - require.NoError(t, err) + assertPaymentStatus( + t, paymentDB, info.PaymentIdentifier, finalStatus, + ) - byteOrder.PutUint32(scratch[:4], 0) - _, err = b.Write(scratch[:4]) - require.NoError(t, err) + // Finally assert we cannot register more attempts. + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + require.Equal(t, registerErr, err) + } - // Get the PaymentCreationInfo. - err = paymentBucket.Put(duplicatePaymentCreationInfoKey, b.Bytes()) - require.NoError(t, err) + for _, test := range tests { + subTest := fmt.Sprintf("first=%v, second=%v", + test.settleFirst, test.settleLast) - // Duolicate payments are only stored for successes, so add the - // preimage. - err = paymentBucket.Put(duplicatePaymentSettleInfoKey, preImg[:]) - require.NoError(t, err) + t.Run(subTest, func(t *testing.T) { + runSubTest(t, test) + }) + } } diff --git a/payments/db/query.go b/payments/db/query.go new file mode 100644 index 00000000000..5904ab56a4f --- /dev/null +++ b/payments/db/query.go @@ -0,0 +1,69 @@ +package paymentsdb + +// Query represents a query to the payments database starting or ending +// at a certain offset index. The number of retrieved records can be limited. +type Query struct { + // IndexOffset determines the starting point of the payments query and + // is always exclusive. In normal order, the query starts at the next + // higher (available) index compared to IndexOffset. In reversed order, + // the query ends at the next lower (available) index compared to the + // IndexOffset. In the case of a zero index_offset, the query will start + // with the oldest payment when paginating forwards, or will end with + // the most recent payment when paginating backwards. + IndexOffset uint64 + + // MaxPayments is the maximal number of payments returned in the + // payments query. + MaxPayments uint64 + + // Reversed gives a meaning to the IndexOffset. If reversed is set to + // true, the query will fetch payments with indices lower than the + // IndexOffset, otherwise, it will return payments with indices greater + // than the IndexOffset. + Reversed bool + + // If IncludeIncomplete is true, then return payments that have not yet + // fully completed. This means that pending payments, as well as failed + // payments will show up if this field is set to true. + IncludeIncomplete bool + + // CountTotal indicates that all payments currently present in the + // payment index (complete and incomplete) should be counted. + CountTotal bool + + // CreationDateStart, expressed in Unix seconds, if set, filters out + // all payments with a creation date greater than or equal to it. + CreationDateStart int64 + + // CreationDateEnd, expressed in Unix seconds, if set, filters out all + // payments with a creation date less than or equal to it. + CreationDateEnd int64 +} + +// Response contains the result of a query to the payments database. +// It includes the set of payments that match the query and integers which +// represent the index of the first and last item returned in the series of +// payments. These integers allow callers to resume their query in the event +// that the query's response exceeds the max number of returnable events. +type Response struct { + // Payments is the set of payments returned from the database for the + // Query. + Payments []*MPPayment + + // FirstIndexOffset is the index of the first element in the set of + // returned MPPayments. Callers can use this to resume their query + // in the event that the slice has too many events to fit into a single + // response. The offset can be used to continue reverse pagination. + FirstIndexOffset uint64 + + // LastIndexOffset is the index of the last element in the set of + // returned MPPayments. Callers can use this to resume their query + // in the event that the slice has too many events to fit into a single + // response. The offset can be used to continue forward pagination. + LastIndexOffset uint64 + + // TotalCount represents the total number of payments that are currently + // stored in the payment database. This will only be set if the + // CountTotal field in the query was set to true. + TotalCount uint64 +} diff --git a/payments/db/test_kvdb.go b/payments/db/test_kvdb.go new file mode 100644 index 00000000000..d26a1fcf20d --- /dev/null +++ b/payments/db/test_kvdb.go @@ -0,0 +1,40 @@ +package paymentsdb + +import ( + "testing" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/stretchr/testify/require" +) + +// NewTestDB is a helper function that creates an BBolt database for testing. +func NewTestDB(t *testing.T, opts ...OptionModifier) PaymentDB { + backend, backendCleanup, err := kvdb.GetTestBackend( + t.TempDir(), "kvPaymentDB", + ) + require.NoError(t, err) + + t.Cleanup(backendCleanup) + + paymentDB, err := NewKVPaymentsDB(backend, opts...) + require.NoError(t, err) + + return paymentDB +} + +// NewKVTestDB is a helper function that creates an BBolt database for testing +// and there is no need to convert the interface to the KVPaymentsDB because for +// some unit tests we still need access to the kvdb interface. +func NewKVTestDB(t *testing.T, opts ...OptionModifier) *KVPaymentsDB { + backend, backendCleanup, err := kvdb.GetTestBackend( + t.TempDir(), "kvPaymentDB", + ) + require.NoError(t, err) + + t.Cleanup(backendCleanup) + + paymentDB, err := NewKVPaymentsDB(backend, opts...) + require.NoError(t, err) + + return paymentDB +} diff --git a/routing/control_tower.go b/routing/control_tower.go index b2e4b7bd518..072d8a42b1a 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -3,9 +3,9 @@ package routing import ( "sync" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/multimutex" + paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/queue" ) @@ -13,23 +13,23 @@ import ( // the payment lifecycle. type DBMPPayment interface { // GetState returns the current state of the payment. - GetState() *channeldb.MPPaymentState + GetState() *paymentsdb.MPPaymentState // Terminated returns true if the payment is in a final state. Terminated() bool // GetStatus returns the current status of the payment. - GetStatus() channeldb.PaymentStatus + GetStatus() paymentsdb.PaymentStatus // NeedWaitAttempts specifies whether the payment needs to wait for the // outcome of an attempt. NeedWaitAttempts() (bool, error) // GetHTLCs returns all HTLCs of this payment. - GetHTLCs() []channeldb.HTLCAttempt + GetHTLCs() []paymentsdb.HTLCAttempt // InFlightHTLCs returns all HTLCs that are in flight. - InFlightHTLCs() []channeldb.HTLCAttempt + InFlightHTLCs() []paymentsdb.HTLCAttempt // AllowMoreAttempts is used to decide whether we can safely attempt // more HTLCs for a given payment state. Return an error if the payment @@ -38,7 +38,7 @@ type DBMPPayment interface { // TerminalInfo returns the settled HTLC attempt or the payment's // failure reason. - TerminalInfo() (*channeldb.HTLCAttempt, *channeldb.FailureReason) + TerminalInfo() (*paymentsdb.HTLCAttempt, *paymentsdb.FailureReason) } // ControlTower tracks all outgoing payments made, whose primary purpose is to @@ -49,7 +49,7 @@ type DBMPPayment interface { type ControlTower interface { // This method checks that no succeeded payment exist for this payment // hash. - InitPayment(lntypes.Hash, *channeldb.PaymentCreationInfo) error + InitPayment(lntypes.Hash, *paymentsdb.PaymentCreationInfo) error // DeleteFailedAttempts removes all failed HTLCs from the db. It should // be called for a given payment whenever all inflight htlcs are @@ -57,7 +57,7 @@ type ControlTower interface { DeleteFailedAttempts(lntypes.Hash) error // RegisterAttempt atomically records the provided HTLCAttemptInfo. - RegisterAttempt(lntypes.Hash, *channeldb.HTLCAttemptInfo) error + RegisterAttempt(lntypes.Hash, *paymentsdb.HTLCAttemptInfo) error // SettleAttempt marks the given attempt settled with the preimage. If // this is a multi shard payment, this might implicitly mean the the @@ -67,12 +67,12 @@ type ControlTower interface { // error to prevent us from making duplicate payments to the same // payment hash. The provided preimage is atomically saved to the DB // for record keeping. - SettleAttempt(lntypes.Hash, uint64, *channeldb.HTLCSettleInfo) ( - *channeldb.HTLCAttempt, error) + SettleAttempt(lntypes.Hash, uint64, *paymentsdb.HTLCSettleInfo) ( + *paymentsdb.HTLCAttempt, error) // FailAttempt marks the given payment attempt failed. - FailAttempt(lntypes.Hash, uint64, *channeldb.HTLCFailInfo) ( - *channeldb.HTLCAttempt, error) + FailAttempt(lntypes.Hash, uint64, *paymentsdb.HTLCFailInfo) ( + *paymentsdb.HTLCAttempt, error) // FetchPayment fetches the payment corresponding to the given payment // hash. @@ -84,10 +84,10 @@ type ControlTower interface { // invoking this method, InitPayment should return nil on its next call // for this payment hash, allowing the user to make a subsequent // payment. - FailPayment(lntypes.Hash, channeldb.FailureReason) error + FailPayment(lntypes.Hash, paymentsdb.FailureReason) error // FetchInFlightPayments returns all payments with status InFlight. - FetchInFlightPayments() ([]*channeldb.MPPayment, error) + FetchInFlightPayments() ([]*paymentsdb.MPPayment, error) // SubscribePayment subscribes to updates for the payment with the given // hash. A first update with the current state of the payment is always @@ -151,7 +151,7 @@ func (s *controlTowerSubscriberImpl) Updates() <-chan interface{} { // controlTower is persistent implementation of ControlTower to restrict // double payment sending. type controlTower struct { - db *channeldb.KVPaymentsDB + db paymentsdb.PaymentDB // subscriberIndex is used to provide a unique id for each subscriber // to all payments. This is used to easily remove the subscriber when @@ -168,7 +168,7 @@ type controlTower struct { } // NewControlTower creates a new instance of the controlTower. -func NewControlTower(db *channeldb.KVPaymentsDB) ControlTower { +func NewControlTower(db paymentsdb.PaymentDB) ControlTower { return &controlTower{ db: db, subscribersAllPayments: make( @@ -184,7 +184,7 @@ func NewControlTower(db *channeldb.KVPaymentsDB) ControlTower { // method returns successfully, the payment is guaranteed to be in the // Initiated state. func (p *controlTower) InitPayment(paymentHash lntypes.Hash, - info *channeldb.PaymentCreationInfo) error { + info *paymentsdb.PaymentCreationInfo) error { err := p.db.InitPayment(paymentHash, info) if err != nil { @@ -215,7 +215,7 @@ func (p *controlTower) DeleteFailedAttempts(paymentHash lntypes.Hash) error { // RegisterAttempt atomically records the provided HTLCAttemptInfo to the // DB. func (p *controlTower) RegisterAttempt(paymentHash lntypes.Hash, - attempt *channeldb.HTLCAttemptInfo) error { + attempt *paymentsdb.HTLCAttemptInfo) error { p.paymentsMtx.Lock(paymentHash) defer p.paymentsMtx.Unlock(paymentHash) @@ -235,8 +235,8 @@ func (p *controlTower) RegisterAttempt(paymentHash lntypes.Hash, // this is a multi shard payment, this might implicitly mean the the // full payment succeeded. func (p *controlTower) SettleAttempt(paymentHash lntypes.Hash, - attemptID uint64, settleInfo *channeldb.HTLCSettleInfo) ( - *channeldb.HTLCAttempt, error) { + attemptID uint64, settleInfo *paymentsdb.HTLCSettleInfo) ( + *paymentsdb.HTLCAttempt, error) { p.paymentsMtx.Lock(paymentHash) defer p.paymentsMtx.Unlock(paymentHash) @@ -254,8 +254,8 @@ func (p *controlTower) SettleAttempt(paymentHash lntypes.Hash, // FailAttempt marks the given payment attempt failed. func (p *controlTower) FailAttempt(paymentHash lntypes.Hash, - attemptID uint64, failInfo *channeldb.HTLCFailInfo) ( - *channeldb.HTLCAttempt, error) { + attemptID uint64, failInfo *paymentsdb.HTLCFailInfo) ( + *paymentsdb.HTLCAttempt, error) { p.paymentsMtx.Lock(paymentHash) defer p.paymentsMtx.Unlock(paymentHash) @@ -286,7 +286,7 @@ func (p *controlTower) FetchPayment(paymentHash lntypes.Hash) ( // NOTE: This method will overwrite the failure reason if the payment is already // failed. func (p *controlTower) FailPayment(paymentHash lntypes.Hash, - reason channeldb.FailureReason) error { + reason paymentsdb.FailureReason) error { p.paymentsMtx.Lock(paymentHash) defer p.paymentsMtx.Unlock(paymentHash) @@ -303,7 +303,9 @@ func (p *controlTower) FailPayment(paymentHash lntypes.Hash, } // FetchInFlightPayments returns all payments with status InFlight. -func (p *controlTower) FetchInFlightPayments() ([]*channeldb.MPPayment, error) { +func (p *controlTower) FetchInFlightPayments() ([]*paymentsdb.MPPayment, + error) { + return p.db.FetchInFlightPayments() } @@ -386,7 +388,7 @@ func (p *controlTower) SubscribeAllPayments() (ControlTowerSubscriber, error) { // be executed atomically (by means of a lock) with the database update to // guarantee consistency of the notifications. func (p *controlTower) notifySubscribers(paymentHash lntypes.Hash, - event *channeldb.MPPayment) { + event *paymentsdb.MPPayment) { // Get all subscribers for this payment. p.subscribersMtx.Lock() diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index f45c2de385f..3e01065962d 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -50,7 +50,7 @@ func TestControlTowerSubscribeUnknown(t *testing.T) { db := initDB(t) - paymentDB, err := channeldb.NewKVPaymentsDB( + paymentDB, err := paymentsdb.NewKVPaymentsDB( db, paymentsdb.WithKeepFailedPaymentAttempts(true), ) @@ -70,7 +70,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { db := initDB(t) - paymentDB, err := channeldb.NewKVPaymentsDB(db) + paymentDB, err := paymentsdb.NewKVPaymentsDB(db) require.NoError(t, err) pControl := NewControlTower(paymentDB) @@ -102,7 +102,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { require.NoError(t, err, "expected subscribe to succeed, but got") // Mark the payment as successful. - settleInfo := channeldb.HTLCSettleInfo{ + settleInfo := paymentsdb.HTLCSettleInfo{ Preimage: preimg, } htlcAttempt, err := pControl.SettleAttempt( @@ -126,19 +126,27 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { } for i, s := range subscribers { - var result *channeldb.MPPayment + var result *paymentsdb.MPPayment for result == nil || !result.Terminated() { select { case item := <-s.Updates(): - result = item.(*channeldb.MPPayment) + payment, ok := item.(*paymentsdb.MPPayment) + require.True( + t, ok, "unexpected payment type: %T", + item) + + result = payment + case <-time.After(testTimeout): t.Fatal("timeout waiting for payment result") } } - require.Equalf(t, channeldb.StatusSucceeded, result.GetStatus(), - "subscriber %v failed, want %s, got %s", i, - channeldb.StatusSucceeded, result.GetStatus()) + require.Equalf(t, paymentsdb.StatusSucceeded, + result.GetStatus(), "subscriber %v failed, want %s, "+ + "got %s", i, paymentsdb.StatusSucceeded, + result.GetStatus(), + ) attempt, _ := result.TerminalInfo() if attempt.Settle.Preimage != preimg { @@ -192,7 +200,7 @@ func TestKVPaymentsDBSubscribeAllSuccess(t *testing.T) { db := initDB(t) - paymentDB, err := channeldb.NewKVPaymentsDB( + paymentDB, err := paymentsdb.NewKVPaymentsDB( db, paymentsdb.WithKeepFailedPaymentAttempts(true), ) @@ -228,7 +236,7 @@ func TestKVPaymentsDBSubscribeAllSuccess(t *testing.T) { require.NoError(t, err) // Mark the first payment as successful. - settleInfo1 := channeldb.HTLCSettleInfo{ + settleInfo1 := paymentsdb.HTLCSettleInfo{ Preimage: preimg1, } htlcAttempt1, err := pControl.SettleAttempt( @@ -241,7 +249,7 @@ func TestKVPaymentsDBSubscribeAllSuccess(t *testing.T) { ) // Mark the second payment as successful. - settleInfo2 := channeldb.HTLCSettleInfo{ + settleInfo2 := paymentsdb.HTLCSettleInfo{ Preimage: preimg2, } htlcAttempt2, err := pControl.SettleAttempt( @@ -255,14 +263,20 @@ func TestKVPaymentsDBSubscribeAllSuccess(t *testing.T) { // The two payments will be asserted individually, store the last update // for each payment. - results := make(map[lntypes.Hash]*channeldb.MPPayment) + results := make(map[lntypes.Hash]*paymentsdb.MPPayment) // After exactly 6 updates both payments will/should have completed. for i := 0; i < 6; i++ { select { case item := <-subscription.Updates(): - id := item.(*channeldb.MPPayment).Info.PaymentIdentifier - results[id] = item.(*channeldb.MPPayment) + payment, ok := item.(*paymentsdb.MPPayment) + require.True( + t, ok, "unexpected payment type: %T", + item) + + id := payment.Info.PaymentIdentifier + results[id] = payment + case <-time.After(testTimeout): require.Fail(t, "timeout waiting for payment result") } @@ -270,7 +284,7 @@ func TestKVPaymentsDBSubscribeAllSuccess(t *testing.T) { result1 := results[info1.PaymentIdentifier] require.Equal( - t, channeldb.StatusSucceeded, result1.GetStatus(), + t, paymentsdb.StatusSucceeded, result1.GetStatus(), "unexpected payment state payment 1", ) @@ -288,7 +302,7 @@ func TestKVPaymentsDBSubscribeAllSuccess(t *testing.T) { result2 := results[info2.PaymentIdentifier] require.Equal( - t, channeldb.StatusSucceeded, result2.GetStatus(), + t, paymentsdb.StatusSucceeded, result2.GetStatus(), "unexpected payment state payment 2", ) @@ -311,7 +325,7 @@ func TestKVPaymentsDBSubscribeAllImmediate(t *testing.T) { db := initDB(t) - paymentDB, err := channeldb.NewKVPaymentsDB( + paymentDB, err := paymentsdb.NewKVPaymentsDB( db, paymentsdb.WithKeepFailedPaymentAttempts(true), ) @@ -337,11 +351,17 @@ func TestKVPaymentsDBSubscribeAllImmediate(t *testing.T) { select { case update := <-subscription.Updates(): require.NotNil(t, update) + payment, ok := update.(*paymentsdb.MPPayment) + if !ok { + t.Fatalf("unexpected payment type: %T", update) + } + require.Equal( t, info.PaymentIdentifier, - update.(*channeldb.MPPayment).Info.PaymentIdentifier, + payment.Info.PaymentIdentifier, ) require.Len(t, subscription.Updates(), 0) + case <-time.After(testTimeout): require.Fail(t, "timeout waiting for payment result") } @@ -354,7 +374,7 @@ func TestKVPaymentsDBUnsubscribeSuccess(t *testing.T) { db := initDB(t) - paymentDB, err := channeldb.NewKVPaymentsDB( + paymentDB, err := paymentsdb.NewKVPaymentsDB( db, paymentsdb.WithKeepFailedPaymentAttempts(true), ) @@ -411,8 +431,8 @@ func TestKVPaymentsDBUnsubscribeSuccess(t *testing.T) { subscription2.Close() // Register another update. - failInfo := channeldb.HTLCFailInfo{ - Reason: channeldb.HTLCFailInternal, + failInfo := paymentsdb.HTLCFailInfo{ + Reason: paymentsdb.HTLCFailInternal, } _, err = pControl.FailAttempt( info.PaymentIdentifier, attempt.AttemptID, &failInfo, @@ -429,7 +449,7 @@ func testKVPaymentsDBSubscribeFail(t *testing.T, registerAttempt, db := initDB(t) - paymentDB, err := channeldb.NewKVPaymentsDB( + paymentDB, err := paymentsdb.NewKVPaymentsDB( db, paymentsdb.WithKeepFailedPaymentAttempts( keepFailedPaymentAttempts, @@ -465,8 +485,8 @@ func testKVPaymentsDBSubscribeFail(t *testing.T, registerAttempt, } // Fail the payment attempt. - failInfo := channeldb.HTLCFailInfo{ - Reason: channeldb.HTLCFailInternal, + failInfo := paymentsdb.HTLCFailInfo{ + Reason: paymentsdb.HTLCFailInternal, } htlcAttempt, err := pControl.FailAttempt( info.PaymentIdentifier, attempt.AttemptID, &failInfo, @@ -481,7 +501,7 @@ func testKVPaymentsDBSubscribeFail(t *testing.T, registerAttempt, // Mark the payment as failed. err = pControl.FailPayment( - info.PaymentIdentifier, channeldb.FailureReasonTimeout, + info.PaymentIdentifier, paymentsdb.FailureReasonTimeout, ) if err != nil { t.Fatal(err) @@ -498,17 +518,22 @@ func testKVPaymentsDBSubscribeFail(t *testing.T, registerAttempt, } for i, s := range subscribers { - var result *channeldb.MPPayment + var result *paymentsdb.MPPayment for result == nil || !result.Terminated() { select { case item := <-s.Updates(): - result = item.(*channeldb.MPPayment) + payment, ok := item.(*paymentsdb.MPPayment) + require.True( + t, ok, "unexpected payment type: %T", + item) + + result = payment case <-time.After(testTimeout): t.Fatal("timeout waiting for payment result") } } - if result.GetStatus() == channeldb.StatusSucceeded { + if result.GetStatus() == paymentsdb.StatusSucceeded { t.Fatal("unexpected payment state") } @@ -533,11 +558,11 @@ func testKVPaymentsDBSubscribeFail(t *testing.T, registerAttempt, len(result.HTLCs)) } - require.Equalf(t, channeldb.StatusFailed, result.GetStatus(), + require.Equalf(t, paymentsdb.StatusFailed, result.GetStatus(), "subscriber %v failed, want %s, got %s", i, - channeldb.StatusFailed, result.GetStatus()) + paymentsdb.StatusFailed, result.GetStatus()) - if *result.FailureReason != channeldb.FailureReasonTimeout { + if *result.FailureReason != paymentsdb.FailureReasonTimeout { t.Fatal("unexpected failure reason") } @@ -559,7 +584,7 @@ func initDB(t *testing.T) *channeldb.DB { ) } -func genInfo() (*channeldb.PaymentCreationInfo, *channeldb.HTLCAttemptInfo, +func genInfo() (*paymentsdb.PaymentCreationInfo, *paymentsdb.HTLCAttemptInfo, lntypes.Preimage, error) { preimage, err := genPreimage() @@ -572,14 +597,14 @@ func genInfo() (*channeldb.PaymentCreationInfo, *channeldb.HTLCAttemptInfo, var hash lntypes.Hash copy(hash[:], rhash[:]) - attempt, err := channeldb.NewHtlcAttempt( + attempt, err := paymentsdb.NewHtlcAttempt( 1, priv, testRoute, time.Time{}, &hash, ) if err != nil { return nil, nil, lntypes.Preimage{}, err } - return &channeldb.PaymentCreationInfo{ + return &paymentsdb.PaymentCreationInfo{ PaymentIdentifier: rhash, Value: testRoute.ReceiverAmt(), CreationTime: time.Unix(time.Now().Unix(), 0), diff --git a/routing/missioncontrol.go b/routing/missioncontrol.go index 89e14f523bc..b03724a2e62 100644 --- a/routing/missioncontrol.go +++ b/routing/missioncontrol.go @@ -11,11 +11,11 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btclog/v2" "github.com/btcsuite/btcwallet/walletdb" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" + paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" ) @@ -617,7 +617,7 @@ func (m *MissionControl) GetPairHistorySnapshot( // payment attempts need to be made. func (m *MissionControl) ReportPaymentFail(paymentID uint64, rt *route.Route, failureSourceIdx *int, failure lnwire.FailureMessage) ( - *channeldb.FailureReason, error) { + *paymentsdb.FailureReason, error) { timestamp := m.cfg.clock.Now() @@ -649,7 +649,7 @@ func (m *MissionControl) ReportPaymentSuccess(paymentID uint64, // processPaymentResult stores a payment result in the mission control store and // updates mission control's in-memory state. func (m *MissionControl) processPaymentResult(result *paymentResult) ( - *channeldb.FailureReason, error) { + *paymentsdb.FailureReason, error) { // Store complete result in database. m.store.AddResult(result) @@ -667,7 +667,7 @@ func (m *MissionControl) processPaymentResult(result *paymentResult) ( // estimates. It returns a bool indicating whether this error is a final error // and no further payment attempts need to be made. func (m *MissionControl) applyPaymentResult( - result *paymentResult) *channeldb.FailureReason { + result *paymentResult) *paymentsdb.FailureReason { // Interpret result. i := interpretResult(&result.route.Val, result.failure.ValOpt()) diff --git a/routing/mock_test.go b/routing/mock_test.go index d9528254e81..9cb938f74f2 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -7,7 +7,6 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" @@ -134,12 +133,12 @@ var _ MissionControlQuerier = (*mockMissionControlOld)(nil) func (m *mockMissionControlOld) ReportPaymentFail( paymentID uint64, rt *route.Route, failureSourceIdx *int, failure lnwire.FailureMessage) ( - *channeldb.FailureReason, error) { + *paymentsdb.FailureReason, error) { // Report a permanent failure if this is an error caused // by incorrect details. if failure.Code() == lnwire.CodeIncorrectOrUnknownPaymentDetails { - reason := channeldb.FailureReasonPaymentDetails + reason := paymentsdb.FailureReasonPaymentDetails return &reason, nil } @@ -248,11 +247,11 @@ func (m *mockPayerOld) CleanStore(pids map[uint64]struct{}) error { } type initArgs struct { - c *channeldb.PaymentCreationInfo + c *paymentsdb.PaymentCreationInfo } type registerAttemptArgs struct { - a *channeldb.HTLCAttemptInfo + a *paymentsdb.HTLCAttemptInfo } type settleAttemptArgs struct { @@ -260,22 +259,22 @@ type settleAttemptArgs struct { } type failAttemptArgs struct { - reason *channeldb.HTLCFailInfo + reason *paymentsdb.HTLCFailInfo } type failPaymentArgs struct { - reason channeldb.FailureReason + reason paymentsdb.FailureReason } type testPayment struct { - info channeldb.PaymentCreationInfo - attempts []channeldb.HTLCAttempt + info paymentsdb.PaymentCreationInfo + attempts []paymentsdb.HTLCAttempt } type mockControlTowerOld struct { payments map[lntypes.Hash]*testPayment successful map[lntypes.Hash]struct{} - failed map[lntypes.Hash]channeldb.FailureReason + failed map[lntypes.Hash]paymentsdb.FailureReason init chan initArgs registerAttempt chan registerAttemptArgs @@ -293,12 +292,12 @@ func makeMockControlTower() *mockControlTowerOld { return &mockControlTowerOld{ payments: make(map[lntypes.Hash]*testPayment), successful: make(map[lntypes.Hash]struct{}), - failed: make(map[lntypes.Hash]channeldb.FailureReason), + failed: make(map[lntypes.Hash]paymentsdb.FailureReason), } } func (m *mockControlTowerOld) InitPayment(phash lntypes.Hash, - c *channeldb.PaymentCreationInfo) error { + c *paymentsdb.PaymentCreationInfo) error { if m.init != nil { m.init <- initArgs{c} @@ -355,7 +354,7 @@ func (m *mockControlTowerOld) DeleteFailedAttempts(phash lntypes.Hash) error { } func (m *mockControlTowerOld) RegisterAttempt(phash lntypes.Hash, - a *channeldb.HTLCAttemptInfo) error { + a *paymentsdb.HTLCAttemptInfo) error { if m.registerAttempt != nil { m.registerAttempt <- registerAttemptArgs{a} @@ -400,7 +399,7 @@ func (m *mockControlTowerOld) RegisterAttempt(phash lntypes.Hash, } // Add attempt to payment. - p.attempts = append(p.attempts, channeldb.HTLCAttempt{ + p.attempts = append(p.attempts, paymentsdb.HTLCAttempt{ HTLCAttemptInfo: *a, }) m.payments[phash] = p @@ -409,8 +408,8 @@ func (m *mockControlTowerOld) RegisterAttempt(phash lntypes.Hash, } func (m *mockControlTowerOld) SettleAttempt(phash lntypes.Hash, - pid uint64, settleInfo *channeldb.HTLCSettleInfo) ( - *channeldb.HTLCAttempt, error) { + pid uint64, settleInfo *paymentsdb.HTLCSettleInfo) ( + *paymentsdb.HTLCAttempt, error) { if m.settleAttempt != nil { m.settleAttempt <- settleAttemptArgs{settleInfo.Preimage} @@ -442,7 +441,8 @@ func (m *mockControlTowerOld) SettleAttempt(phash lntypes.Hash, // Mark the payment successful on first settled attempt. m.successful[phash] = struct{}{} - return &channeldb.HTLCAttempt{ + + return &paymentsdb.HTLCAttempt{ Settle: settleInfo, }, nil } @@ -451,7 +451,7 @@ func (m *mockControlTowerOld) SettleAttempt(phash lntypes.Hash, } func (m *mockControlTowerOld) FailAttempt(phash lntypes.Hash, pid uint64, - failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) { + failInfo *paymentsdb.HTLCFailInfo) (*paymentsdb.HTLCAttempt, error) { if m.failAttempt != nil { m.failAttempt <- failAttemptArgs{failInfo} @@ -480,7 +480,8 @@ func (m *mockControlTowerOld) FailAttempt(phash lntypes.Hash, pid uint64, } p.attempts[i].Failure = failInfo - return &channeldb.HTLCAttempt{ + + return &paymentsdb.HTLCAttempt{ Failure: failInfo, }, nil } @@ -489,7 +490,7 @@ func (m *mockControlTowerOld) FailAttempt(phash lntypes.Hash, pid uint64, } func (m *mockControlTowerOld) FailPayment(phash lntypes.Hash, - reason channeldb.FailureReason) error { + reason paymentsdb.FailureReason) error { m.Lock() defer m.Unlock() @@ -518,14 +519,14 @@ func (m *mockControlTowerOld) FetchPayment(phash lntypes.Hash) ( } func (m *mockControlTowerOld) fetchPayment(phash lntypes.Hash) ( - *channeldb.MPPayment, error) { + *paymentsdb.MPPayment, error) { p, ok := m.payments[phash] if !ok { return nil, paymentsdb.ErrPaymentNotInitiated } - mp := &channeldb.MPPayment{ + mp := &paymentsdb.MPPayment{ Info: &p.info, } @@ -545,7 +546,7 @@ func (m *mockControlTowerOld) fetchPayment(phash lntypes.Hash) ( } func (m *mockControlTowerOld) FetchInFlightPayments() ( - []*channeldb.MPPayment, error) { + []*paymentsdb.MPPayment, error) { if m.fetchInFlight != nil { m.fetchInFlight <- struct{}{} @@ -555,7 +556,7 @@ func (m *mockControlTowerOld) FetchInFlightPayments() ( defer m.Unlock() // In flight are all payments not successful or failed. - var fl []*channeldb.MPPayment + var fl []*paymentsdb.MPPayment for hash := range m.payments { if _, ok := m.successful[hash]; ok { continue @@ -664,7 +665,7 @@ var _ MissionControlQuerier = (*mockMissionControl)(nil) func (m *mockMissionControl) ReportPaymentFail( paymentID uint64, rt *route.Route, failureSourceIdx *int, failure lnwire.FailureMessage) ( - *channeldb.FailureReason, error) { + *paymentsdb.FailureReason, error) { args := m.Called(paymentID, rt, failureSourceIdx, failure) @@ -673,7 +674,7 @@ func (m *mockMissionControl) ReportPaymentFail( return nil, args.Error(1) } - return args.Get(0).(*channeldb.FailureReason), args.Error(1) + return args.Get(0).(*paymentsdb.FailureReason), args.Error(1) } func (m *mockMissionControl) ReportPaymentSuccess(paymentID uint64, @@ -733,7 +734,7 @@ type mockControlTower struct { var _ ControlTower = (*mockControlTower)(nil) func (m *mockControlTower) InitPayment(phash lntypes.Hash, - c *channeldb.PaymentCreationInfo) error { + c *paymentsdb.PaymentCreationInfo) error { args := m.Called(phash, c) return args.Error(0) @@ -745,15 +746,15 @@ func (m *mockControlTower) DeleteFailedAttempts(phash lntypes.Hash) error { } func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash, - a *channeldb.HTLCAttemptInfo) error { + a *paymentsdb.HTLCAttemptInfo) error { args := m.Called(phash, a) return args.Error(0) } func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, - pid uint64, settleInfo *channeldb.HTLCSettleInfo) ( - *channeldb.HTLCAttempt, error) { + pid uint64, settleInfo *paymentsdb.HTLCSettleInfo) ( + *paymentsdb.HTLCAttempt, error) { args := m.Called(phash, pid, settleInfo) @@ -762,11 +763,11 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, return nil, args.Error(1) } - return attempt.(*channeldb.HTLCAttempt), args.Error(1) + return attempt.(*paymentsdb.HTLCAttempt), args.Error(1) } func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, - failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) { + failInfo *paymentsdb.HTLCFailInfo) (*paymentsdb.HTLCAttempt, error) { args := m.Called(phash, pid, failInfo) @@ -775,11 +776,11 @@ func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, return nil, args.Error(1) } - return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) + return attempt.(*paymentsdb.HTLCAttempt), args.Error(1) } func (m *mockControlTower) FailPayment(phash lntypes.Hash, - reason channeldb.FailureReason) error { + reason paymentsdb.FailureReason) error { args := m.Called(phash, reason) return args.Error(0) @@ -800,10 +801,10 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( } func (m *mockControlTower) FetchInFlightPayments() ( - []*channeldb.MPPayment, error) { + []*paymentsdb.MPPayment, error) { args := m.Called() - return args.Get(0).([]*channeldb.MPPayment), args.Error(1) + return args.Get(0).([]*paymentsdb.MPPayment), args.Error(1) } func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) ( @@ -826,14 +827,14 @@ type mockMPPayment struct { var _ DBMPPayment = (*mockMPPayment)(nil) -func (m *mockMPPayment) GetState() *channeldb.MPPaymentState { +func (m *mockMPPayment) GetState() *paymentsdb.MPPaymentState { args := m.Called() - return args.Get(0).(*channeldb.MPPaymentState) + return args.Get(0).(*paymentsdb.MPPaymentState) } -func (m *mockMPPayment) GetStatus() channeldb.PaymentStatus { +func (m *mockMPPayment) GetStatus() paymentsdb.PaymentStatus { args := m.Called() - return args.Get(0).(channeldb.PaymentStatus) + return args.Get(0).(paymentsdb.PaymentStatus) } func (m *mockMPPayment) Terminated() bool { @@ -847,14 +848,14 @@ func (m *mockMPPayment) NeedWaitAttempts() (bool, error) { return args.Bool(0), args.Error(1) } -func (m *mockMPPayment) GetHTLCs() []channeldb.HTLCAttempt { +func (m *mockMPPayment) GetHTLCs() []paymentsdb.HTLCAttempt { args := m.Called() - return args.Get(0).([]channeldb.HTLCAttempt) + return args.Get(0).([]paymentsdb.HTLCAttempt) } -func (m *mockMPPayment) InFlightHTLCs() []channeldb.HTLCAttempt { +func (m *mockMPPayment) InFlightHTLCs() []paymentsdb.HTLCAttempt { args := m.Called() - return args.Get(0).([]channeldb.HTLCAttempt) + return args.Get(0).([]paymentsdb.HTLCAttempt) } func (m *mockMPPayment) AllowMoreAttempts() (bool, error) { @@ -862,24 +863,24 @@ func (m *mockMPPayment) AllowMoreAttempts() (bool, error) { return args.Bool(0), args.Error(1) } -func (m *mockMPPayment) TerminalInfo() (*channeldb.HTLCAttempt, - *channeldb.FailureReason) { +func (m *mockMPPayment) TerminalInfo() (*paymentsdb.HTLCAttempt, + *paymentsdb.FailureReason) { args := m.Called() var ( - settleInfo *channeldb.HTLCAttempt - failureInfo *channeldb.FailureReason + settleInfo *paymentsdb.HTLCAttempt + failureInfo *paymentsdb.FailureReason ) settle := args.Get(0) if settle != nil { - settleInfo = settle.(*channeldb.HTLCAttempt) + settleInfo = settle.(*paymentsdb.HTLCAttempt) } reason := args.Get(1) if reason != nil { - failureInfo = reason.(*channeldb.FailureReason) + failureInfo = reason.(*paymentsdb.FailureReason) } return settleInfo, failureInfo diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 7a6443ab66a..245e3b3c151 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -9,13 +9,13 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/davecgh/go-spew/spew" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" + paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/shards" "github.com/lightningnetwork/lnd/tlv" @@ -29,7 +29,7 @@ var ErrPaymentLifecycleExiting = errors.New("payment lifecycle exiting") // HTLC. type switchResult struct { // attempt is the HTLC sent to the switch. - attempt *channeldb.HTLCAttempt + attempt *paymentsdb.HTLCAttempt // result is sent from the switch which contains either a preimage if // ths HTLC is settled or an error if it's failed. @@ -59,7 +59,7 @@ type paymentLifecycle struct { // an HTLC attempt, which is always mounted to `p.collectResultAsync` // except in unit test, where we use a much simpler resultCollector to // decouple the test flow for the payment lifecycle. - resultCollector func(attempt *channeldb.HTLCAttempt) + resultCollector func(attempt *paymentsdb.HTLCAttempt) } // newPaymentLifecycle initiates a new payment lifecycle and returns it. @@ -226,6 +226,19 @@ func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte, // critical error during path finding. lifecycle: for { + // Before we attempt any new shard, we'll check to see if we've + // gone past the payment attempt timeout or if the context was + // canceled. If the context is done, the payment is marked as + // failed and we reload the latest payment state to reflect + // this. + // + // NOTE: This can be called several times if there are more + // attempts to be resolved after the timeout or context is + // cancelled. + if err := p.checkContext(ctx); err != nil { + return exitWithErr(err) + } + // We update the payment state on every iteration. currentPayment, ps, err := p.reloadPayment() if err != nil { @@ -241,19 +254,11 @@ lifecycle: // We now proceed our lifecycle with the following tasks in // order, - // 1. check context. - // 2. request route. - // 3. create HTLC attempt. - // 4. send HTLC attempt. - // 5. collect HTLC attempt result. + // 1. request route. + // 2. create HTLC attempt. + // 3. send HTLC attempt. + // 4. collect HTLC attempt result. // - // Before we attempt any new shard, we'll check to see if we've - // gone past the payment attempt timeout, or if the context was - // cancelled, or the router is exiting. In any of these cases, - // we'll stop this payment attempt short. - if err := p.checkContext(ctx); err != nil { - return exitWithErr(err) - } // Now decide the next step of the current lifecycle. step, err := p.decideNextStep(payment) @@ -347,13 +352,13 @@ func (p *paymentLifecycle) checkContext(ctx context.Context) error { // user-provided timeout was reached, or the context was // canceled, either to a manual cancellation or due to an // unknown error. - var reason channeldb.FailureReason + var reason paymentsdb.FailureReason if errors.Is(ctx.Err(), context.DeadlineExceeded) { - reason = channeldb.FailureReasonTimeout + reason = paymentsdb.FailureReasonTimeout log.Warnf("Payment attempt not completed before "+ "context timeout, id=%s", p.identifier.String()) } else { - reason = channeldb.FailureReasonCanceled + reason = paymentsdb.FailureReasonCanceled log.Warnf("Payment attempt context canceled, id=%s", p.identifier.String()) } @@ -381,7 +386,7 @@ func (p *paymentLifecycle) checkContext(ctx context.Context) error { // requestRoute is responsible for finding a route to be used to create an HTLC // attempt. func (p *paymentLifecycle) requestRoute( - ps *channeldb.MPPaymentState) (*route.Route, error) { + ps *paymentsdb.MPPaymentState) (*route.Route, error) { remainingFees := p.calcFeeBudget(ps.FeesPaid) @@ -450,14 +455,14 @@ type attemptResult struct { err error // attempt is the attempt structure as recorded in the database. - attempt *channeldb.HTLCAttempt + attempt *paymentsdb.HTLCAttempt } // collectResultAsync launches a goroutine that will wait for the result of the // given HTLC attempt to be available then save its result in a map. Once // received, it will send the result returned from the switch to channel // `resultCollected`. -func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { +func (p *paymentLifecycle) collectResultAsync(attempt *paymentsdb.HTLCAttempt) { log.Debugf("Collecting result for attempt %v in payment %v", attempt.AttemptID, p.identifier) @@ -499,7 +504,7 @@ func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { // collectResult waits for the result of the given HTLC attempt to be sent by // the switch and returns it. func (p *paymentLifecycle) collectResult( - attempt *channeldb.HTLCAttempt) (*htlcswitch.PaymentResult, error) { + attempt *paymentsdb.HTLCAttempt) (*htlcswitch.PaymentResult, error) { log.Tracef("Collecting result for attempt %v", lnutils.SpewLogClosure(attempt)) @@ -576,7 +581,7 @@ func (p *paymentLifecycle) collectResult( // by using the route info provided. The `remainingAmt` is used to decide // whether this is the last attempt. func (p *paymentLifecycle) registerAttempt(rt *route.Route, - remainingAmt lnwire.MilliSatoshi) (*channeldb.HTLCAttempt, error) { + remainingAmt lnwire.MilliSatoshi) (*paymentsdb.HTLCAttempt, error) { // If this route will consume the last remaining amount to send // to the receiver, this will be our last shard (for now). @@ -603,7 +608,7 @@ func (p *paymentLifecycle) registerAttempt(rt *route.Route, // createNewPaymentAttempt creates a new payment attempt from the given route. func (p *paymentLifecycle) createNewPaymentAttempt(rt *route.Route, - lastShard bool) (*channeldb.HTLCAttempt, error) { + lastShard bool) (*paymentsdb.HTLCAttempt, error) { // Generate a new key to be used for this attempt. sessionKey, err := generateNewSessionKey() @@ -643,7 +648,7 @@ func (p *paymentLifecycle) createNewPaymentAttempt(rt *route.Route, // We now have all the information needed to populate the current // attempt information. - return channeldb.NewHtlcAttempt( + return paymentsdb.NewHtlcAttempt( attemptID, sessionKey, *rt, p.router.cfg.Clock.Now(), &hash, ) } @@ -652,7 +657,7 @@ func (p *paymentLifecycle) createNewPaymentAttempt(rt *route.Route, // the payment. If this attempt fails, then we'll continue on to the next // available route. func (p *paymentLifecycle) sendAttempt( - attempt *channeldb.HTLCAttempt) (*attemptResult, error) { + attempt *paymentsdb.HTLCAttempt) (*attemptResult, error) { log.Debugf("Sending HTLC attempt(id=%v, total_amt=%v, first_hop_amt=%d"+ ") for payment %v", attempt.AttemptID, @@ -789,7 +794,7 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error { // failAttemptAndPayment fails both the payment and its attempt via the // router's control tower, which marks the payment as failed in db. func (p *paymentLifecycle) failPaymentAndAttempt( - attemptID uint64, reason *channeldb.FailureReason, + attemptID uint64, reason *paymentsdb.FailureReason, sendErr error) (*attemptResult, error) { log.Errorf("Payment %v failed: final_outcome=%v, raw_err=%v", @@ -818,10 +823,10 @@ func (p *paymentLifecycle) failPaymentAndAttempt( // the error type, the error is either the final outcome of the payment or we // need to continue with an alternative route. A final outcome is indicated by // a non-nil reason value. -func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, +func (p *paymentLifecycle) handleSwitchErr(attempt *paymentsdb.HTLCAttempt, sendErr error) (*attemptResult, error) { - internalErrorReason := channeldb.FailureReasonError + internalErrorReason := paymentsdb.FailureReasonError attemptID := attempt.AttemptID // reportAndFail is a helper closure that reports the failure to the @@ -1025,34 +1030,34 @@ func (p *paymentLifecycle) failAttempt(attemptID uint64, // marshallError marshall an error as received from the switch to a structure // that is suitable for database storage. -func marshallError(sendError error, time time.Time) *channeldb.HTLCFailInfo { - response := &channeldb.HTLCFailInfo{ +func marshallError(sendError error, time time.Time) *paymentsdb.HTLCFailInfo { + response := &paymentsdb.HTLCFailInfo{ FailTime: time, } switch { case errors.Is(sendError, htlcswitch.ErrPaymentIDNotFound): - response.Reason = channeldb.HTLCFailInternal + response.Reason = paymentsdb.HTLCFailInternal return response case errors.Is(sendError, htlcswitch.ErrUnreadableFailureMessage): - response.Reason = channeldb.HTLCFailUnreadable + response.Reason = paymentsdb.HTLCFailUnreadable return response } var rtErr htlcswitch.ClearTextError ok := errors.As(sendError, &rtErr) if !ok { - response.Reason = channeldb.HTLCFailInternal + response.Reason = paymentsdb.HTLCFailInternal return response } message := rtErr.WireMessage() if message != nil { - response.Reason = channeldb.HTLCFailMessage + response.Reason = paymentsdb.HTLCFailMessage response.Message = message } else { - response.Reason = channeldb.HTLCFailUnknown + response.Reason = paymentsdb.HTLCFailUnknown } // If the ClearTextError received is a ForwardingError, the error @@ -1077,7 +1082,7 @@ func marshallError(sendError error, time time.Time) *channeldb.HTLCFailInfo { // enabled, the `Hash` field in their HTLC attempts is nil. In that case, we use // the payment hash as the `attempt.Hash` as they are identical. func (p *paymentLifecycle) patchLegacyPaymentHash( - a channeldb.HTLCAttempt) channeldb.HTLCAttempt { + a paymentsdb.HTLCAttempt) paymentsdb.HTLCAttempt { // Exit early if this is not a legacy attempt. if a.Hash != nil { @@ -1129,7 +1134,7 @@ func (p *paymentLifecycle) reloadInflightAttempts() (DBMPPayment, error) { // reloadPayment returns the latest payment found in the db (control tower). func (p *paymentLifecycle) reloadPayment() (DBMPPayment, - *channeldb.MPPaymentState, error) { + *paymentsdb.MPPaymentState, error) { // Read the db to get the latest state of the payment. payment, err := p.router.cfg.Control.FetchPayment(p.identifier) @@ -1149,7 +1154,7 @@ func (p *paymentLifecycle) reloadPayment() (DBMPPayment, // handleAttemptResult processes the result of an HTLC attempt returned from // the htlcswitch. -func (p *paymentLifecycle) handleAttemptResult(attempt *channeldb.HTLCAttempt, +func (p *paymentLifecycle) handleAttemptResult(attempt *paymentsdb.HTLCAttempt, result *htlcswitch.PaymentResult) (*attemptResult, error) { // If the result has an error, we need to further process it by failing @@ -1174,7 +1179,7 @@ func (p *paymentLifecycle) handleAttemptResult(attempt *channeldb.HTLCAttempt, // move the shard to the settled state. htlcAttempt, err := p.router.cfg.Control.SettleAttempt( p.identifier, attempt.AttemptID, - &channeldb.HTLCSettleInfo{ + &paymentsdb.HTLCSettleInfo{ Preimage: result.Preimage, SettleTime: p.router.cfg.Clock.Now(), }, @@ -1199,7 +1204,7 @@ func (p *paymentLifecycle) handleAttemptResult(attempt *channeldb.HTLCAttempt, // tower. An attemptResult is returned, indicating the final outcome of this // HTLC attempt. func (p *paymentLifecycle) collectAndHandleResult( - attempt *channeldb.HTLCAttempt) (*attemptResult, error) { + attempt *paymentsdb.HTLCAttempt) (*attemptResult, error) { result, err := p.collectResult(attempt) if err != nil { diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 0ee751196f9..8094f507574 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -8,13 +8,13 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/mock" @@ -116,7 +116,7 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { // Overwrite the collectResultAsync to focus on testing the payment // lifecycle within the goroutine. - resultCollector := func(attempt *channeldb.HTLCAttempt) { + resultCollector := func(attempt *paymentsdb.HTLCAttempt) { mockers.collectResultsCount++ } p.resultCollector = resultCollector @@ -147,7 +147,7 @@ func setupTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { m.payment, nil, ).Once() - htlcs := []channeldb.HTLCAttempt{} + htlcs := []paymentsdb.HTLCAttempt{} m.payment.On("InFlightHTLCs").Return(htlcs).Once() return p, m @@ -258,11 +258,11 @@ func createDummyRoute(t *testing.T, amt lnwire.MilliSatoshi) *route.Route { } func makeSettledAttempt(t *testing.T, total int, - preimage lntypes.Preimage) *channeldb.HTLCAttempt { + preimage lntypes.Preimage) *paymentsdb.HTLCAttempt { - a := &channeldb.HTLCAttempt{ + a := &paymentsdb.HTLCAttempt{ HTLCAttemptInfo: makeAttemptInfo(t, total), - Settle: &channeldb.HTLCSettleInfo{Preimage: preimage}, + Settle: &paymentsdb.HTLCSettleInfo{Preimage: preimage}, } hash := preimage.Hash() @@ -271,18 +271,18 @@ func makeSettledAttempt(t *testing.T, total int, return a } -func makeFailedAttempt(t *testing.T, total int) *channeldb.HTLCAttempt { - return &channeldb.HTLCAttempt{ +func makeFailedAttempt(t *testing.T, total int) *paymentsdb.HTLCAttempt { + return &paymentsdb.HTLCAttempt{ HTLCAttemptInfo: makeAttemptInfo(t, total), - Failure: &channeldb.HTLCFailInfo{ - Reason: channeldb.HTLCFailInternal, + Failure: &paymentsdb.HTLCFailInfo{ + Reason: paymentsdb.HTLCFailInternal, }, } } -func makeAttemptInfo(t *testing.T, amt int) channeldb.HTLCAttemptInfo { +func makeAttemptInfo(t *testing.T, amt int) paymentsdb.HTLCAttemptInfo { rt := createDummyRoute(t, lnwire.MilliSatoshi(amt)) - return channeldb.HTLCAttemptInfo{ + return paymentsdb.HTLCAttemptInfo{ Route: *rt, Hash: &lntypes.Hash{1, 2, 3}, } @@ -302,7 +302,7 @@ func TestCheckTimeoutTimedOut(t *testing.T) { // Mock the control tower's `FailPayment` method. ct := &mockControlTower{} ct.On("FailPayment", - p.identifier, channeldb.FailureReasonTimeout).Return(nil) + p.identifier, paymentsdb.FailureReasonTimeout).Return(nil) // Mount the mocked control tower. p.router.cfg.Control = ct @@ -323,7 +323,7 @@ func TestCheckTimeoutTimedOut(t *testing.T) { // Mock `FailPayment` to return a dummy error. ct = &mockControlTower{} ct.On("FailPayment", - p.identifier, channeldb.FailureReasonTimeout).Return(errDummy) + p.identifier, paymentsdb.FailureReasonTimeout).Return(errDummy) // Mount the mocked control tower. p.router.cfg.Control = ct @@ -378,7 +378,7 @@ func TestRequestRouteSucceed(t *testing.T) { p.paySession = paySession // Create a dummy payment state. - ps := &channeldb.MPPaymentState{ + ps := &paymentsdb.MPPaymentState{ NumAttemptsInFlight: 1, RemainingAmt: 1, FeesPaid: 100, @@ -415,7 +415,7 @@ func TestRequestRouteHandleCriticalErr(t *testing.T) { p.paySession = paySession // Create a dummy payment state. - ps := &channeldb.MPPaymentState{ + ps := &paymentsdb.MPPaymentState{ NumAttemptsInFlight: 1, RemainingAmt: 1, FeesPaid: 100, @@ -449,7 +449,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) { p, m := newTestPaymentLifecycle(t) // Create a dummy payment state. - ps := &channeldb.MPPaymentState{ + ps := &paymentsdb.MPPaymentState{ NumAttemptsInFlight: 1, RemainingAmt: 1, FeesPaid: 100, @@ -467,7 +467,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) { // The payment should be failed with reason no route. m.control.On("FailPayment", - p.identifier, channeldb.FailureReasonNoRoute, + p.identifier, paymentsdb.FailureReasonNoRoute, ).Return(nil).Once() result, err := p.requestRoute(ps) @@ -498,7 +498,7 @@ func TestRequestRouteFailPaymentError(t *testing.T) { p.paySession = paySession // Create a dummy payment state with zero inflight attempts. - ps := &channeldb.MPPaymentState{ + ps := &paymentsdb.MPPaymentState{ NumAttemptsInFlight: 0, RemainingAmt: 1, FeesPaid: 100, @@ -819,14 +819,14 @@ func TestResumePaymentFailOnTimeout(t *testing.T) { m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() // 2. calls `GetState` and return the state. - ps := &channeldb.MPPaymentState{ + ps := &paymentsdb.MPPaymentState{ RemainingAmt: paymentAmt, } m.payment.On("GetState").Return(ps).Once() // NOTE: GetStatus is only used to populate the logs which is not // critical, so we loosen the checks on how many times it's been called. - m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + m.payment.On("GetStatus").Return(paymentsdb.StatusInFlight) // 3. make the timeout happens instantly and sleep one millisecond to // make sure it timed out. @@ -837,7 +837,7 @@ func TestResumePaymentFailOnTimeout(t *testing.T) { // 4. the payment should be failed with reason timeout. m.control.On("FailPayment", - p.identifier, channeldb.FailureReasonTimeout, + p.identifier, paymentsdb.FailureReasonTimeout, ).Return(nil).Once() // 5. decideNextStep now returns stepExit. @@ -848,7 +848,7 @@ func TestResumePaymentFailOnTimeout(t *testing.T) { m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once() // 7. the payment returns the failed reason. - reason := channeldb.FailureReasonTimeout + reason := paymentsdb.FailureReasonTimeout m.payment.On("TerminalInfo").Return(nil, &reason) // Send the payment and assert it failed with the timeout reason. @@ -868,25 +868,16 @@ func TestResumePaymentFailOnTimeoutErr(t *testing.T) { // Create a test paymentLifecycle with the initial two calls mocked. p, m := setupTestPaymentLifecycle(t) - paymentAmt := lnwire.MilliSatoshi(10000) - - // We now enter the payment lifecycle loop. - // - // 1. calls `FetchPayment` and return the payment. - m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() - - // 2. calls `GetState` and return the state. - ps := &channeldb.MPPaymentState{ - RemainingAmt: paymentAmt, - } - m.payment.On("GetState").Return(ps).Once() + // We now enter the payment lifecycle loop, we will check the router + // quit channel in the beginning and quit immediately without reloading + // the payment. // NOTE: GetStatus is only used to populate the logs which is // not critical so we loosen the checks on how many times it's // been called. - m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + m.payment.On("GetStatus").Return(paymentsdb.StatusInFlight) - // 3. quit the router to return an error. + // Quit the router to return an error. close(p.router.quit) // Send the payment and assert it failed when router is shutting down. @@ -919,21 +910,21 @@ func TestResumePaymentFailContextCancel(t *testing.T) { m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() // 2. calls `GetState` and return the state. - ps := &channeldb.MPPaymentState{ + ps := &paymentsdb.MPPaymentState{ RemainingAmt: paymentAmt, } m.payment.On("GetState").Return(ps).Once() // NOTE: GetStatus is only used to populate the logs which is not // critical, so we loosen the checks on how many times it's been called. - m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + m.payment.On("GetStatus").Return(paymentsdb.StatusInFlight) // 3. Cancel the context and skip the FailPayment error to trigger the // context cancellation of the payment. cancel() m.control.On( - "FailPayment", p.identifier, channeldb.FailureReasonCanceled, + "FailPayment", p.identifier, paymentsdb.FailureReasonCanceled, ).Return(nil).Once() // 4. decideNextStep now returns stepExit. @@ -944,7 +935,7 @@ func TestResumePaymentFailContextCancel(t *testing.T) { m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once() // 6. We will observe FailureReasonError if the context was cancelled. - reason := channeldb.FailureReasonError + reason := paymentsdb.FailureReasonError m.payment.On("TerminalInfo").Return(nil, &reason) // Send the payment and assert it failed with the timeout reason. @@ -972,7 +963,7 @@ func TestResumePaymentFailOnStepErr(t *testing.T) { m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() // 2. calls `GetState` and return the state. - ps := &channeldb.MPPaymentState{ + ps := &paymentsdb.MPPaymentState{ RemainingAmt: paymentAmt, } m.payment.On("GetState").Return(ps).Once() @@ -980,7 +971,7 @@ func TestResumePaymentFailOnStepErr(t *testing.T) { // NOTE: GetStatus is only used to populate the logs which is // not critical so we loosen the checks on how many times it's // been called. - m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + m.payment.On("GetStatus").Return(paymentsdb.StatusInFlight) // 3. decideNextStep now returns an error. m.payment.On("AllowMoreAttempts").Return(false, errDummy).Once() @@ -1010,7 +1001,7 @@ func TestResumePaymentFailOnRequestRouteErr(t *testing.T) { m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() // 2. calls `GetState` and return the state. - ps := &channeldb.MPPaymentState{ + ps := &paymentsdb.MPPaymentState{ RemainingAmt: paymentAmt, } m.payment.On("GetState").Return(ps).Once() @@ -1018,7 +1009,7 @@ func TestResumePaymentFailOnRequestRouteErr(t *testing.T) { // NOTE: GetStatus is only used to populate the logs which is // not critical so we loosen the checks on how many times it's // been called. - m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + m.payment.On("GetStatus").Return(paymentsdb.StatusInFlight) // 3. decideNextStep now returns stepProceed. m.payment.On("AllowMoreAttempts").Return(true, nil).Once() @@ -1056,7 +1047,7 @@ func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) { m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() // 2. calls `GetState` and return the state. - ps := &channeldb.MPPaymentState{ + ps := &paymentsdb.MPPaymentState{ RemainingAmt: paymentAmt, } m.payment.On("GetState").Return(ps).Once() @@ -1064,7 +1055,7 @@ func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) { // NOTE: GetStatus is only used to populate the logs which is // not critical so we loosen the checks on how many times it's // been called. - m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + m.payment.On("GetStatus").Return(paymentsdb.StatusInFlight) // 3. decideNextStep now returns stepProceed. m.payment.On("AllowMoreAttempts").Return(true, nil).Once() @@ -1116,7 +1107,7 @@ func TestResumePaymentFailOnSendAttemptErr(t *testing.T) { m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() // 2. calls `GetState` and return the state. - ps := &channeldb.MPPaymentState{ + ps := &paymentsdb.MPPaymentState{ RemainingAmt: paymentAmt, } m.payment.On("GetState").Return(ps).Once() @@ -1124,7 +1115,7 @@ func TestResumePaymentFailOnSendAttemptErr(t *testing.T) { // NOTE: GetStatus is only used to populate the logs which is // not critical so we loosen the checks on how many times it's // been called. - m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + m.payment.On("GetStatus").Return(paymentsdb.StatusInFlight) // 3. decideNextStep now returns stepProceed. m.payment.On("AllowMoreAttempts").Return(true, nil).Once() @@ -1170,7 +1161,7 @@ func TestResumePaymentFailOnSendAttemptErr(t *testing.T) { // which we'd fail the payment, cancel the shard and fail the attempt. // // `FailPayment` should be called with an internal reason. - reason := channeldb.FailureReasonError + reason := paymentsdb.FailureReasonError m.control.On("FailPayment", p.identifier, reason).Return(nil).Once() // `CancelShard` should be called with the attemptID. @@ -1208,7 +1199,7 @@ func TestResumePaymentSuccess(t *testing.T) { m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() // 1.2. calls `GetState` and return the state. - ps := &channeldb.MPPaymentState{ + ps := &paymentsdb.MPPaymentState{ RemainingAmt: paymentAmt, } m.payment.On("GetState").Return(ps).Once() @@ -1216,7 +1207,7 @@ func TestResumePaymentSuccess(t *testing.T) { // NOTE: GetStatus is only used to populate the logs which is // not critical so we loosen the checks on how many times it's // been called. - m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + m.payment.On("GetStatus").Return(paymentsdb.StatusInFlight) // 1.3. decideNextStep now returns stepProceed. m.payment.On("AllowMoreAttempts").Return(true, nil).Once() @@ -1309,7 +1300,7 @@ func TestResumePaymentSuccessWithTwoAttempts(t *testing.T) { m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() // 1.2. calls `GetState` and return the state. - ps := &channeldb.MPPaymentState{ + ps := &paymentsdb.MPPaymentState{ RemainingAmt: paymentAmt, } m.payment.On("GetState").Return(ps).Once() @@ -1317,7 +1308,7 @@ func TestResumePaymentSuccessWithTwoAttempts(t *testing.T) { // NOTE: GetStatus is only used to populate the logs which is // not critical so we loosen the checks on how many times it's // been called. - m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + m.payment.On("GetStatus").Return(paymentsdb.StatusInFlight) // 1.3. decideNextStep now returns stepProceed. m.payment.On("AllowMoreAttempts").Return(true, nil).Once() @@ -1461,7 +1452,7 @@ func TestCollectResultExitOnErr(t *testing.T) { // which we'd fail the payment, cancel the shard and fail the attempt. // // `FailPayment` should be called with an internal reason. - reason := channeldb.FailureReasonError + reason := paymentsdb.FailureReasonError m.control.On("FailPayment", p.identifier, reason).Return(nil).Once() // `CancelShard` should be called with the attemptID. @@ -1507,7 +1498,7 @@ func TestCollectResultExitOnResultErr(t *testing.T) { // which we'd fail the payment, cancel the shard and fail the attempt. // // `FailPayment` should be called with an internal reason. - reason := channeldb.FailureReasonError + reason := paymentsdb.FailureReasonError m.control.On("FailPayment", p.identifier, reason).Return(nil).Once() // `CancelShard` should be called with the attemptID. @@ -1842,7 +1833,7 @@ func TestReloadInflightAttemptsLegacy(t *testing.T) { m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() // 2. calls `InFlightHTLCs` and return the attempt. - attempts := []channeldb.HTLCAttempt{*attempt} + attempts := []paymentsdb.HTLCAttempt{*attempt} m.payment.On("InFlightHTLCs").Return(attempts).Once() // 3. Mock the htlcswitch to return a the result chan. diff --git a/routing/payment_session.go b/routing/payment_session.go index cf20251ed4b..bb795211c21 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -5,12 +5,12 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btclog/v2" - "github.com/lightningnetwork/lnd/channeldb" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" + paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/routing/route" ) @@ -104,7 +104,7 @@ func (e noRouteError) Error() string { } // FailureReason converts a path finding error into a payment-level failure. -func (e noRouteError) FailureReason() channeldb.FailureReason { +func (e noRouteError) FailureReason() paymentsdb.FailureReason { switch e { case errNoTlvPayload, @@ -114,13 +114,13 @@ func (e noRouteError) FailureReason() channeldb.FailureReason { errUnknownRequiredFeature, errMissingDependentFeature: - return channeldb.FailureReasonNoRoute + return paymentsdb.FailureReasonNoRoute case errInsufficientBalance: - return channeldb.FailureReasonInsufficientBalance + return paymentsdb.FailureReasonInsufficientBalance default: - return channeldb.FailureReasonError + return paymentsdb.FailureReasonError } } diff --git a/routing/result_interpretation.go b/routing/result_interpretation.go index 55c4246f58f..713be73eabe 100644 --- a/routing/result_interpretation.go +++ b/routing/result_interpretation.go @@ -5,17 +5,17 @@ import ( "fmt" "io" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" + paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" ) // Instantiate variables to allow taking a reference from the failure reason. var ( - reasonError = channeldb.FailureReasonError - reasonIncorrectDetails = channeldb.FailureReasonPaymentDetails + reasonError = paymentsdb.FailureReasonError + reasonIncorrectDetails = paymentsdb.FailureReasonPaymentDetails ) // pairResult contains the result of the interpretation of a payment attempt for @@ -70,7 +70,7 @@ type interpretedResult struct { // finalFailureReason is set to a non-nil value if it makes no more // sense to start another payment attempt. It will contain the reason // why. - finalFailureReason *channeldb.FailureReason + finalFailureReason *paymentsdb.FailureReason // policyFailure is set to a node pair if there is a policy failure on // that connection. This is used to control the second chance logic for diff --git a/routing/router.go b/routing/router.go index f7347c86f1b..1dac1085d49 100644 --- a/routing/router.go +++ b/routing/router.go @@ -15,7 +15,6 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/amp" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" @@ -175,7 +174,7 @@ type MissionControlQuerier interface { // need to be made. ReportPaymentFail(attemptID uint64, rt *route.Route, failureSourceIdx *int, failure lnwire.FailureMessage) ( - *channeldb.FailureReason, error) + *paymentsdb.FailureReason, error) // ReportPaymentSuccess reports a successful payment to mission control // as input for future probability estimates. @@ -999,7 +998,7 @@ func (r *ChannelRouter) PreparePayment(payment *LightningPayment) ( // already in-flight. // // TODO(roasbeef): store records as part of creation info? - info := &channeldb.PaymentCreationInfo{ + info := &paymentsdb.PaymentCreationInfo{ PaymentIdentifier: payment.Identifier(), Value: payment.Amount, CreationTime: r.cfg.Clock.Now(), @@ -1038,7 +1037,7 @@ func (r *ChannelRouter) PreparePayment(payment *LightningPayment) ( // SendToRoute sends a payment using the provided route and fails the payment // when an error is returned from the attempt. func (r *ChannelRouter) SendToRoute(htlcHash lntypes.Hash, rt *route.Route, - firstHopCustomRecords lnwire.CustomRecords) (*channeldb.HTLCAttempt, + firstHopCustomRecords lnwire.CustomRecords) (*paymentsdb.HTLCAttempt, error) { return r.sendToRoute(htlcHash, rt, false, firstHopCustomRecords) @@ -1048,7 +1047,7 @@ func (r *ChannelRouter) SendToRoute(htlcHash lntypes.Hash, rt *route.Route, // the payment ONLY when a terminal error is returned from the attempt. func (r *ChannelRouter) SendToRouteSkipTempErr(htlcHash lntypes.Hash, rt *route.Route, - firstHopCustomRecords lnwire.CustomRecords) (*channeldb.HTLCAttempt, + firstHopCustomRecords lnwire.CustomRecords) (*paymentsdb.HTLCAttempt, error) { return r.sendToRoute(htlcHash, rt, true, firstHopCustomRecords) @@ -1062,13 +1061,13 @@ func (r *ChannelRouter) SendToRouteSkipTempErr(htlcHash lntypes.Hash, // the payment won't be failed unless a terminal error has occurred. func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, skipTempErr bool, - firstHopCustomRecords lnwire.CustomRecords) (*channeldb.HTLCAttempt, + firstHopCustomRecords lnwire.CustomRecords) (*paymentsdb.HTLCAttempt, error) { // Helper function to fail a payment. It makes sure the payment is only // failed once so that the failure reason is not overwritten. failPayment := func(paymentIdentifier lntypes.Hash, - reason channeldb.FailureReason) error { + reason paymentsdb.FailureReason) error { payment, fetchErr := r.cfg.Control.FetchPayment( paymentIdentifier, @@ -1122,7 +1121,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // Record this payment hash with the ControlTower, ensuring it is not // already in-flight. - info := &channeldb.PaymentCreationInfo{ + info := &paymentsdb.PaymentCreationInfo{ PaymentIdentifier: paymentIdentifier, Value: amt, CreationTime: r.cfg.Clock.Now(), @@ -1191,7 +1190,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // Since for SendToRoute we won't retry in case the shard fails, we'll // mark the payment failed with the control tower immediately if the // skipTempErr is false. - reason := channeldb.FailureReasonError + reason := paymentsdb.FailureReasonError // If we failed to send the HTLC, we need to further decide if we want // to fail the payment. @@ -1448,7 +1447,7 @@ func (r *ChannelRouter) resumePayments() error { } // launchPayment is a helper closure that handles resuming the payment. - launchPayment := func(payment *channeldb.MPPayment) { + launchPayment := func(payment *paymentsdb.MPPayment) { defer r.wg.Done() // Get the hashes used for the outstanding HTLCs. @@ -1523,7 +1522,7 @@ func (r *ChannelRouter) resumePayments() error { // attempt to NOT be saved, resulting a payment being stuck forever. More info: // - https://github.com/lightningnetwork/lnd/issues/8146 // - https://github.com/lightningnetwork/lnd/pull/8174 -func (r *ChannelRouter) failStaleAttempt(a channeldb.HTLCAttempt, +func (r *ChannelRouter) failStaleAttempt(a paymentsdb.HTLCAttempt, payHash lntypes.Hash) { // We can only fail inflight HTLCs so we skip the settled/failed ones. @@ -1605,8 +1604,8 @@ func (r *ChannelRouter) failStaleAttempt(a channeldb.HTLCAttempt, // Fail the attempt in db. If there's an error, there's nothing we can // do here but logging it. - failInfo := &channeldb.HTLCFailInfo{ - Reason: channeldb.HTLCFailUnknown, + failInfo := &paymentsdb.HTLCFailInfo{ + Reason: paymentsdb.HTLCFailUnknown, FailTime: r.cfg.Clock.Now(), } _, err = r.cfg.Control.FailAttempt(payHash, a.AttemptID, failInfo) diff --git a/routing/router_test.go b/routing/router_test.go index d149c423e00..702c29db95e 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -21,7 +21,6 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph" @@ -32,6 +31,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/zpay32" @@ -1091,7 +1091,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // The final error returned should also indicate that the peer wasn't // online (the last error we returned). - require.Equal(t, channeldb.FailureReasonNoRoute, err) + require.Equal(t, paymentsdb.FailureReasonNoRoute, err) // Inspect the two attempts that were made before the payment failed. p, err := ctx.router.cfg.Control.FetchPayment(*payment.paymentHash) @@ -2429,7 +2429,7 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { mock.Anything, mock.Anything, mock.Anything, ).Return(permErr) - failureReason := channeldb.FailureReasonPaymentDetails + failureReason := paymentsdb.FailureReasonPaymentDetails missionControl.On("ReportPaymentFail", mock.Anything, rt, mock.Anything, mock.Anything, ).Return(&failureReason, nil) diff --git a/rpcserver.go b/rpcserver.go index c922b23e4fd..7add215799f 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -71,6 +71,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chanfunding" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/macaroons" + paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/peer" "github.com/lightningnetwork/lnd/peernotifier" "github.com/lightningnetwork/lnd/record" @@ -5892,7 +5893,7 @@ func (r *rpcServer) dispatchPaymentIntent( payment, ) } else { - var attempt *channeldb.HTLCAttempt + var attempt *paymentsdb.HTLCAttempt attempt, routerErr = r.server.chanRouter.SendToRoute( payIntent.rHash, payIntent.route, nil, ) @@ -7512,7 +7513,7 @@ func (r *rpcServer) ListPayments(ctx context.Context, } } - query := channeldb.PaymentsQuery{ + query := paymentsdb.Query{ IndexOffset: req.IndexOffset, MaxPayments: req.MaxPayments, Reversed: req.Reversed, @@ -7528,7 +7529,7 @@ func (r *rpcServer) ListPayments(ctx context.Context, query.MaxPayments = math.MaxUint64 } - paymentsQuerySlice, err := r.server.kvPaymentsDB.QueryPayments( + paymentsQuerySlice, err := r.server.paymentsDB.QueryPayments( ctx, query, ) if err != nil { @@ -7611,7 +7612,7 @@ func (r *rpcServer) DeletePayment(ctx context.Context, rpcsLog.Infof("[DeletePayment] payment_identifier=%v, "+ "failed_htlcs_only=%v", hash, req.FailedHtlcsOnly) - err = r.server.kvPaymentsDB.DeletePayment(hash, req.FailedHtlcsOnly) + err = r.server.paymentsDB.DeletePayment(hash, req.FailedHtlcsOnly) if err != nil { return nil, err } @@ -7651,7 +7652,7 @@ func (r *rpcServer) DeleteAllPayments(ctx context.Context, "failed_htlcs_only=%v", req.FailedPaymentsOnly, req.FailedHtlcsOnly) - numDeletedPayments, err := r.server.kvPaymentsDB.DeletePayments( + numDeletedPayments, err := r.server.paymentsDB.DeletePayments( req.FailedPaymentsOnly, req.FailedHtlcsOnly, ) if err != nil { diff --git a/scripts/keys/hieblmi.asc b/scripts/keys/hieblmi.asc index a5fdfdbb6eb..dd034984379 100644 --- a/scripts/keys/hieblmi.asc +++ b/scripts/keys/hieblmi.asc @@ -11,20 +11,20 @@ rRZoszUECFJQRDighatiXUjZFutWKhZC41OknRAPjCJ5KX0cs0/lFh5SLE7u+v+G Wr2S5cfyvfYrTqQfMUY5j2BOzH2ZC+GtRi2sGjnS8l78ncP0SkdLtDCnmb3KD+CY hj9MjFtUpHmcwd+QHcxGYDzSQDOENi02qleMCPO+rsCEPVXTdTaTIOKmg0ZvklUE THxI2CTdoh3K4a/eOsLscSTUWy5IZXJWuzmolB77B3uzqledJVfmpeO+hQARAQAB -tCNTbHlnaHRuaW5nIDxoaWVibG1pQHByb3Rvbm1haWwuY29tPokCVwQTAQgAQRYh -BDL36h56Azn303FkufgtRW6gI8m/BQJkegPyAhsDBQkB4TOABQsJCAcCAiICBhUK -CQgLAgQWAgMBAh4HAheAAAoJEPgtRW6gI8m/QREP+wdNFRE6cXn6rpROVORnCC+p -BiKBS+80ODncM98UiYbU6BKNCjCtKH4twWnYOEivVJt/CVDE6vd/fHHlJMWQ/SSE -5xK0jS3MrCPqgdbeu73wecrrD48bURBj9ZEV8INQshXliY43i6FmztvdUWHuFEAI -BpODhN/LDKUSkGVx1ov4H2WthFTY/rdaAtJ32WqVcTdAnpH7XIs0Moz2CJEmWzLP -VD4qz5r5tpcQ8ssZ90Pndtm2OY3N77DQZdyCu/IQKzdA7zVMwzBgQ42tm3bC0dRE -J26QgeDVY9sfOBhUT0nQa4c62a8jg4Kcb1yq3z2XbBHBMrq+H+VxyjKrUhyNRUdo -9qyvWWVJmzFU2AskaMq//T0OTZLyWPcsta57nWBMf3UBzCCTsVZJFTODKdoDnK3C -jmLzd/sg5FdryDfi1/Epww6MQ6l1gMmUELaWwQ+2fl03GlUpH+jtwDTGhMdX48/c -8qGkzyZogyhiqm1XdCvk1PCFk5VHFvzxa3CuXS90sTUnQO0Iwo1HQ9lhEtj11wwh -91opA73HiPT+HxWR1o1MO88z3PpkuzHqZnJjQdwDdA8smbA/vmOU/0kItiUmJn6G -NdKKJRFJVfcAtygDke+5qSaJj2rmDu6t4Rnra8Zk5vycf+XyOIvaR/gTOGye0JLo -XXb4GMpxqhs7DqrZQSN0uQINBGR6A/IBEADSEs3XC7g44AykQ4s/INazgbtxXLxY +tCNTbHlnaHRuaW5nIDxoaWVibG1pQHByb3Rvbm1haWwuY29tPokCVwQTAQgAQQIb +AwULCQgHAgIiAgYVCgkICwIEFgIDAQIeBwIXgBYhBDL36h56Azn303FkufgtRW6g +I8m/BQJmXLJVBQkLSK/jAAoJEPgtRW6gI8m/hy8P/jRi+yR9Lo78SPbklI/94iso +zvFAfaAXmU8yVD+kXwhL0ch6+vbZ74DolXGoVByY+OYb5tI4XDbc4/f9r35DABJn +jMCkOBPTfK1lY/kkSn2trc8joEQit43yRCZP9rKvrgXaJtWuYwOpAiaHTR0vylfO +Bhm36yQ4nv+Pv1fgyHcGmtxhYQjNRWgFw2p5/Hzzyc+iNjnCttY4PFBKIiaUUzIl +Vrtz2KWhq4Czvwnd4JXGXzrt+4klqxazGKypv8cnhodVOahx0tDmhnp+1YKerGcV +HljLTNwWBL7eI3oNe1pkZT4BC2meLdOdGDEvY5DgrcnmyLs14m8GOE2otnuNYyXy +jd2bXB1/YTNS+cfCwA6gKBpfXy3PsGkh0Un11NpuWLLsys1WVyOzSy/M11y3VM/E +tvJOXLK/YGMXY55XhqNTRIhsjgLFaKp9L44P2UHqy4hBshRN4hWrQCx5K39TkFTK +f8MR1ep8hiVcpVaArQSULVuTH3+huIJ8Hmc62gOEj5fLpVP98vr32vuoKBCr6SVw +t6YFP1cIfBHrQ19zw+n+56An2CqHjUVacjteQrkE0MKcOrngC0BGg8TLg6mtwnwG +jpJm7ZEYwuvVWULO7JeBXDL2QgaQmM5CGkfg5U1/173W7eChkkEBp9LguxsG6XXA +QG8aozxc5fuTQ6VxjAUmuQINBGR6A/IBEADSEs3XC7g44AykQ4s/INazgbtxXLxY U3J1kTB4aJsZu8LYqAu/jYiSq268ePjueHWjlo9oRgYo1jIYmS/5M2Yv/1PAKRaO bP1544Q1tdfKUBL7um7WWrD9IEo1epwBozMbrvrxMkpJfToIWEawR3UH2/LKpRGo TRoGcfjZrtXUtzks7vQfHusNT92maUmWFPoNr4o/13W+QTOxtiblRr3iPt59NkPQ @@ -35,18 +35,18 @@ BI84yxFRwx5ZThjQuh6SlpgjlVj1jPZ3oECspFv4qxLfKoDPXQV/VGqwfKwVH1nt MUaMeyjbSLIyCZvdB21ygFUF7RbklB3jAlM1jkxC3W95nytjmOUjnwua18sr4f1L l5VJYxcRoY1sLKgKBMTe8/vXSZNzdn6APZ/1LBxDiH9hbSFW0GEweXbLJk3mounX 9C39tNRBwYjWQE/1Mz7hLL1pKGbngJY9AtsS2+xvP/N+P+hm4eT35XeGhsNMJw+0 -9S1YmZzHo04KdQARAQABiQI8BBgBCAAmFiEEMvfqHnoDOffTcWS5+C1FbqAjyb8F -AmR6A/ICGwwFCQHhM4AACgkQ+C1FbqAjyb+KBhAAsQzbo5Bj41wmkJ8Lgr1Najyy -IjNu4gzqH67Ri1PXPUOPRLqYZAFBh6JmbcpkI4ZWtlHBd+wu2n90b7HmT1ZPtQoi -5QtkHpGWr+qpBBl115t58o91phoY+6T1B0T3M6045InZfrL5+1ZMLDQtomc2i8JE -h0wBSxuBZXO3mqS0fygLmKCwtJeQ6cUC4tg5fmSXuBpa5gYX5d0iTNb+4GNCxOG4 -1PSy/P+K1BdjNG6DxJrj0N1cbCO4/pXeWmwVOLIaR6gdMaOTI5on4kFYNlMSG8y0 -VgtECgL6No4Do/jDRPalQAevigx0Pfykco50Xi8iaZ7n+5JrNbqPBZe/JHNs1b8k -WUdBIjLuG5QiIl4DLVCTG6Wpzgw5qkySgqaBxTiEsjmBFrp0gePFcGfBab5ptvSg -xYQFiVk2edon/Cw3OrBgshv1lBooy8+cupPwOXIoCpuN8c1tV0lTV7CachiDjBlv -/ex//apV3ECm/4ka/JVESU8XR17UbOsW0zn2LQyQwd0+QrKjhFGRIP+u2Ntk/9yW -IOI/YqrJe9+mSL/qFZWnfNySDmW4UXYtnHp3peWIi3gO0Du6cgfkOZ5FUxpyK/7h -X+0OTQPyrMNiMlFVMUa4rIYE7Pq5JPHPQJYspkUDb3XPx1zL26vS0QtLnyE8B4Yj -/yn/7SRbazw8Lki8/xY= -=0GPx ------END PGP PUBLIC KEY BLOCK----- \ No newline at end of file +9S1YmZzHo04KdQARAQABiQI8BBgBCAAmAhsMFiEEMvfqHnoDOffTcWS5+C1FbqAj +yb8FAmZfGXsFCQtLFwkACgkQ+C1FbqAjyb8G2Q/9F256fh+fiCvziDTZmsWVNWIP +uPYbmUdBHHoOyW6r7xnbJo5HO/wszi7Bd9JlcCGvoRixuxIIb4RU39znawnmIvAh ++5tp5XmWQ76IZQ8hfhCOjyoAvmUVpsGMkIa8GUEkbL3uT0ernVavz6h1NgpqRsLB +y2vcJoSFMLjf+hgN6Mo+tnj6BNWyjNZnMoPW90fTCePFbCypr8zdcFNYo9ZU3+NT +78ctB4fFvfJdZZiZwF6HKioXA6I6Jk+z+fHeDB/ZAm2vEBlz1DcFRanJpo1wnaVP +B5CGBjq2oqO3yHUNYBqM+/mL7yV2xrK2bIRcPOh1nKyBZvp3bJ0G3gPgSj0osSEX +zWn8tR9TC6uLb0vFZ0WHh6kghXaLdjjEZOvBNDMJ1YB9curDgjdeK74/T0d/GlmI +kDM7zzrEcO3FUbSxuXyIlCJ+hZHFkHV1XX7W+pNWuyX1j2A8PazfFKFRpdGpeNfn +YKOncIbUYoeNm81rgJnEKxi8+LfaeisiE8CXb5UtjHIEDQ5DctCtZZRwT8OXbIx0 +16uJsmbVQ68nFtttQPNrASZAASddwSBmHsLxeQ40/N+vukuy9KwQrfY4Y705xKq1 +ivGh9zNyWky1a1OZ7MeEPByP9urRZly3HbJ+EY65KfanSoEWXulDYvPZJGHT29Tk +D/a7eqozHTGiHNwQUb4= +=F8SH +-----END PGP PUBLIC KEY BLOCK----- diff --git a/server.go b/server.go index d79cc354722..fcb9281fda1 100644 --- a/server.go +++ b/server.go @@ -65,6 +65,7 @@ import ( "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/nat" "github.com/lightningnetwork/lnd/netann" + paymentsdb "github.com/lightningnetwork/lnd/payments/db" "github.com/lightningnetwork/lnd/peer" "github.com/lightningnetwork/lnd/peernotifier" "github.com/lightningnetwork/lnd/pool" @@ -335,11 +336,9 @@ type server struct { invoicesDB invoices.InvoiceDB - // kvPaymentsDB is the DB that contains all functions for managing + // paymentsDB is the DB that contains all functions for managing // payments. - // - // TODO(ziggie): Replace with interface. - kvPaymentsDB *channeldb.KVPaymentsDB + paymentsDB paymentsdb.PaymentDB aliasMgr *aliasmgr.Manager @@ -634,8 +633,9 @@ func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, if cfg.ProtocolOptions.TaprootOverlayChans && implCfg.AuxFundingController.IsNone() { - return nil, fmt.Errorf("taproot overlay flag set, but not " + - "aux controllers") + return nil, fmt.Errorf("taproot overlay flag set, but " + + "overlay channels are not supported " + + "in a standalone lnd build") } //nolint:ll @@ -684,7 +684,7 @@ func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, addrSource: addrSource, miscDB: dbs.ChanStateDB, invoicesDB: dbs.InvoiceDB, - kvPaymentsDB: dbs.KVPaymentsDB, + paymentsDB: dbs.PaymentsDB, cc: cc, sigPool: lnwallet.NewSigPool(cfg.Workers.Sig, cc.Signer), writePool: writePool, @@ -1134,7 +1134,7 @@ func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, PathFindingConfig: pathFindingConfig, } - s.controlTower = routing.NewControlTower(dbs.KVPaymentsDB) + s.controlTower = routing.NewControlTower(dbs.PaymentsDB) strictPruning := cfg.Bitcoin.Node == "neutrino" || cfg.Routing.StrictZombiePruning diff --git a/sqldb/go.mod b/sqldb/go.mod index dbe0d20095d..213042a2dc0 100644 --- a/sqldb/go.mod +++ b/sqldb/go.mod @@ -75,4 +75,4 @@ require ( modernc.org/token v1.1.0 // indirect ) -go 1.23.10 +go 1.23.12 diff --git a/sqldb/sqlc/db_custom.go b/sqldb/sqlc/db_custom.go index 2490e5feb37..f7bc499185a 100644 --- a/sqldb/sqlc/db_custom.go +++ b/sqldb/sqlc/db_custom.go @@ -37,3 +37,127 @@ func makeQueryParams(numTotalArgs, numListArgs int) string { return b.String() } + +// ChannelAndNodes is an interface that provides access to a channel and its +// two nodes. +type ChannelAndNodes interface { + // Channel returns the GraphChannel associated with this interface. + Channel() GraphChannel + + // Node1 returns the first GraphNode associated with this channel. + Node1() GraphNode + + // Node2 returns the second GraphNode associated with this channel. + Node2() GraphNode +} + +// Channel returns the GraphChannel associated with this interface. +// +// NOTE: This method is part of the ChannelAndNodes interface. +func (r GetChannelsByPolicyLastUpdateRangeRow) Channel() GraphChannel { + return r.GraphChannel +} + +// Node1 returns the first GraphNode associated with this channel. +// +// NOTE: This method is part of the ChannelAndNodes interface. +func (r GetChannelsByPolicyLastUpdateRangeRow) Node1() GraphNode { + return r.GraphNode +} + +// Node2 returns the second GraphNode associated with this channel. +// +// NOTE: This method is part of the ChannelAndNodes interface. +func (r GetChannelsByPolicyLastUpdateRangeRow) Node2() GraphNode { + return r.GraphNode_2 +} + +// ChannelAndNodeIDs is an interface that provides access to a channel and its +// two node public keys. +type ChannelAndNodeIDs interface { + // Channel returns the GraphChannel associated with this interface. + Channel() GraphChannel + + // Node1Pub returns the public key of the first node as a byte slice. + Node1Pub() []byte + + // Node2Pub returns the public key of the second node as a byte slice. + Node2Pub() []byte +} + +// Channel returns the GraphChannel associated with this interface. +// +// NOTE: This method is part of the ChannelAndNodeIDs interface. +func (r GetChannelsBySCIDWithPoliciesRow) Channel() GraphChannel { + return r.GraphChannel +} + +// Node1Pub returns the public key of the first node as a byte slice. +// +// NOTE: This method is part of the ChannelAndNodeIDs interface. +func (r GetChannelsBySCIDWithPoliciesRow) Node1Pub() []byte { + return r.GraphNode.PubKey +} + +// Node2Pub returns the public key of the second node as a byte slice. +// +// NOTE: This method is part of the ChannelAndNodeIDs interface. +func (r GetChannelsBySCIDWithPoliciesRow) Node2Pub() []byte { + return r.GraphNode_2.PubKey +} + +// Node1 returns the first GraphNode associated with this channel. +// +// NOTE: This method is part of the ChannelAndNodes interface. +func (r GetChannelsBySCIDWithPoliciesRow) Node1() GraphNode { + return r.GraphNode +} + +// Node2 returns the second GraphNode associated with this channel. +// +// NOTE: This method is part of the ChannelAndNodes interface. +func (r GetChannelsBySCIDWithPoliciesRow) Node2() GraphNode { + return r.GraphNode_2 +} + +// Channel returns the GraphChannel associated with this interface. +// +// NOTE: This method is part of the ChannelAndNodeIDs interface. +func (r GetChannelsByOutpointsRow) Channel() GraphChannel { + return r.GraphChannel +} + +// Node1Pub returns the public key of the first node as a byte slice. +// +// NOTE: This method is part of the ChannelAndNodeIDs interface. +func (r GetChannelsByOutpointsRow) Node1Pub() []byte { + return r.Node1Pubkey +} + +// Node2Pub returns the public key of the second node as a byte slice. +// +// NOTE: This method is part of the ChannelAndNodeIDs interface. +func (r GetChannelsByOutpointsRow) Node2Pub() []byte { + return r.Node2Pubkey +} + +// Channel returns the GraphChannel associated with this interface. +// +// NOTE: This method is part of the ChannelAndNodeIDs interface. +func (r GetChannelsBySCIDRangeRow) Channel() GraphChannel { + return r.GraphChannel +} + +// Node1Pub returns the public key of the first node as a byte slice. +// +// NOTE: This method is part of the ChannelAndNodeIDs interface. +func (r GetChannelsBySCIDRangeRow) Node1Pub() []byte { + return r.Node1PubKey +} + +// Node2Pub returns the public key of the second node as a byte slice. +// +// NOTE: This method is part of the ChannelAndNodeIDs interface. +func (r GetChannelsBySCIDRangeRow) Node2Pub() []byte { + return r.Node2PubKey +} diff --git a/tlv/go.mod b/tlv/go.mod index 383f6550a84..53198338688 100644 --- a/tlv/go.mod +++ b/tlv/go.mod @@ -22,4 +22,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -go 1.23.10 +go 1.23.12 diff --git a/tools/Dockerfile b/tools/Dockerfile index 43a34058ef5..a9b39edb2e3 100644 --- a/tools/Dockerfile +++ b/tools/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.23.10 +FROM golang:1.23.12 RUN apt-get update && apt-get install -y git ENV GOCACHE=/tmp/build/.cache diff --git a/tools/go.mod b/tools/go.mod index 5d40dc35505..2b81efe6f9a 100644 --- a/tools/go.mod +++ b/tools/go.mod @@ -1,6 +1,6 @@ module github.com/lightningnetwork/lnd/tools -go 1.23.10 +go 1.23.12 require ( github.com/btcsuite/btcd v0.24.2 diff --git a/tor/go.mod b/tor/go.mod index 98869a20bc0..2a12e29daca 100644 --- a/tor/go.mod +++ b/tor/go.mod @@ -23,4 +23,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -go 1.23.10 +go 1.23.12