diff --git a/config.go b/config.go index f9aec0284da..38499508797 100644 --- a/config.go +++ b/config.go @@ -626,8 +626,9 @@ func DefaultConfig() Config { RejectCacheSize: channeldb.DefaultRejectCacheSize, ChannelCacheSize: channeldb.DefaultChannelCacheSize, }, - Prometheus: lncfg.DefaultPrometheus(), - Watchtower: lncfg.DefaultWatchtowerCfg(defaultTowerDir), + Prometheus: lncfg.DefaultPrometheus(), + Watchtower: lncfg.DefaultWatchtowerCfg(defaultTowerDir), + ProtocolOptions: lncfg.DefaultProtocol(), HealthChecks: &lncfg.HealthCheckConfig{ ChainCheck: &lncfg.CheckConfig{ Interval: defaultChainInterval, diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index d43a50d906d..9f08f0a7c67 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -7,7 +7,6 @@ import ( "fmt" "io" - "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/lightningnetwork/lnd/channeldb" @@ -18,7 +17,6 @@ import ( "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/queue" - "github.com/lightningnetwork/lnd/tlv" ) // htlcIncomingContestResolver is a ContractResolver that's able to resolve an @@ -522,18 +520,15 @@ func (h *htlcIncomingContestResolver) Supplement(htlc channeldb.HTLC) { func (h *htlcIncomingContestResolver) decodePayload() (*hop.Payload, []byte, error) { - var blindingPoint *btcec.PublicKey - h.htlc.BlindingPoint.WhenSome( - func(b tlv.RecordT[lnwire.BlindingPointTlvType, - *btcec.PublicKey]) { - - blindingPoint = b.Val - }, - ) + blindingInfo := hop.ReconstructBlindingInfo{ + IncomingAmt: h.htlc.Amt, + IncomingExpiry: h.htlc.RefundTimeout, + BlindingKey: h.htlc.BlindingPoint, + } onionReader := bytes.NewReader(h.htlc.OnionBlob[:]) iterator, err := h.OnionProcessor.ReconstructHopIterator( - onionReader, h.htlc.RHash[:], blindingPoint, + onionReader, h.htlc.RHash[:], blindingInfo, ) if err != nil { return nil, nil, err diff --git a/contractcourt/htlc_incoming_contest_resolver_test.go b/contractcourt/htlc_incoming_contest_resolver_test.go index d789858fb46..cc3f9c934ff 100644 --- a/contractcourt/htlc_incoming_contest_resolver_test.go +++ b/contractcourt/htlc_incoming_contest_resolver_test.go @@ -6,7 +6,6 @@ import ( "io/ioutil" "testing" - "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" @@ -290,7 +289,7 @@ type mockOnionProcessor struct { } func (o *mockOnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte, - blindingPoint *btcec.PublicKey) (hop.Iterator, error) { + _ hop.ReconstructBlindingInfo) (hop.Iterator, error) { data, err := ioutil.ReadAll(r) if err != nil { diff --git a/contractcourt/interfaces.go b/contractcourt/interfaces.go index 146670a4147..a48d2373ebf 100644 --- a/contractcourt/interfaces.go +++ b/contractcourt/interfaces.go @@ -4,7 +4,6 @@ import ( "context" "io" - "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" @@ -42,7 +41,7 @@ type OnionProcessor interface { // ReconstructHopIterator attempts to decode a valid sphinx packet from // the passed io.Reader instance. ReconstructHopIterator(r io.Reader, rHash []byte, - blindingKey *btcec.PublicKey) (hop.Iterator, error) + blindingInfo hop.ReconstructBlindingInfo) (hop.Iterator, error) } // UtxoSweeper defines the sweep functions that contract court requires. diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index c4ca138e797..55cd9dd60aa 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -187,7 +187,9 @@ call where arguments were swapped. bitcoin peers' feefilter values into account](https://github.com/lightningnetwork/lnd/pull/8418). * [Preparatory work](https://github.com/lightningnetwork/lnd/pull/8159) for - forwarding of blinded routes was added. + forwarding of blinded routes was added, along with [support](https://github.com/lightningnetwork/lnd/pull/8160) + for forwarding blinded payments. Forwarding of blinded payments is disabled + by default, and the feature is not yet advertised to the network. ## RPC Additions diff --git a/htlcswitch/hop/forwarding_info.go b/htlcswitch/hop/forwarding_info.go index 3ec358a0acb..5a1463c4853 100644 --- a/htlcswitch/hop/forwarding_info.go +++ b/htlcswitch/hop/forwarding_info.go @@ -22,4 +22,9 @@ type ForwardingInfo struct { // OutgoingCTLV is the specified value of the CTLV timelock to be used // in the outgoing HTLC. OutgoingCTLV uint32 + + // NextBlinding is an optional blinding point to be passed to the next + // node in UpdateAddHtlc. This field is set if the htlc is part of a + // blinded route. + NextBlinding lnwire.BlindingPointRecord } diff --git a/htlcswitch/hop/fuzz_test.go b/htlcswitch/hop/fuzz_test.go index 7d528a1c42e..cbbe882601f 100644 --- a/htlcswitch/hop/fuzz_test.go +++ b/htlcswitch/hop/fuzz_test.go @@ -117,7 +117,7 @@ func fuzzPayload(f *testing.F, finalPayload bool) { r := bytes.NewReader(data) - payload1, err := NewPayloadFromReader(r, finalPayload) + payload1, _, err := NewPayloadFromReader(r, finalPayload) if err != nil { return } @@ -146,7 +146,7 @@ func fuzzPayload(f *testing.F, finalPayload bool) { require.NoError(t, err) } - payload2, err := NewPayloadFromReader(&b, finalPayload) + payload2, _, err := NewPayloadFromReader(&b, finalPayload) require.NoError(t, err) require.Equal(t, payload1, payload2) diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index 1829522f48f..df6f5aac727 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -2,6 +2,7 @@ package hop import ( "bytes" + "errors" "fmt" "io" "sync" @@ -9,6 +10,13 @@ import ( "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // ErrDecodeFailed is returned when we can't decode blinded data. + ErrDecodeFailed = errors.New("could not decode blinded data") ) // Iterator is an interface that abstracts away the routing information @@ -47,16 +55,24 @@ type sphinxHopIterator struct { // includes the information required to properly forward the packet to // the next hop. processedPacket *sphinx.ProcessedPacket + + // blindingKit contains the elements required to process hops that are + // part of a blinded route. + blindingKit BlindingKit } // makeSphinxHopIterator converts a processed packet returned from a sphinx -// router and converts it into an hop iterator for usage in the link. +// router and converts it into an hop iterator for usage in the link. A +// blinding kit is passed through for the link to obtain forwarding information +// for blinded routes. func makeSphinxHopIterator(ogPacket *sphinx.OnionPacket, - packet *sphinx.ProcessedPacket) *sphinxHopIterator { + packet *sphinx.ProcessedPacket, + blindingKit BlindingKit) *sphinxHopIterator { return &sphinxHopIterator{ ogPacket: ogPacket, processedPacket: packet, + blindingKit: blindingKit, } } @@ -90,10 +106,29 @@ func (r *sphinxHopIterator) HopPayload() (*Payload, error) { // Otherwise, if this is the TLV payload, then we'll make a new stream // to decode only what we need to make routing decisions. case sphinx.PayloadTLV: - return NewPayloadFromReader( + isFinal := r.processedPacket.Action == sphinx.ExitNode + payload, parsed, err := NewPayloadFromReader( bytes.NewReader(r.processedPacket.Payload.Payload), - r.processedPacket.Action == sphinx.ExitNode, + isFinal, ) + if err != nil { + return nil, err + } + + // If we had an encrypted data payload present, pull out our + // forwarding info from the blob. + if payload.encryptedData != nil { + fwdInfo, err := r.blindingKit.DecryptAndValidateFwdInfo( + payload, isFinal, parsed, + ) + if err != nil { + return nil, err + } + + payload.FwdInfo = *fwdInfo + } + + return payload, err default: return nil, fmt.Errorf("unknown sphinx payload type: %v", @@ -113,6 +148,221 @@ func (r *sphinxHopIterator) ExtractErrorEncrypter( return extracter(r.ogPacket.EphemeralKey) } +// BlindingProcessor is an interface that provides the cryptographic operations +// required for processing blinded hops. +// +// This interface is extracted to allow more granular testing of blinded +// forwarding calculations. +type BlindingProcessor interface { + // DecryptBlindedHopData decrypts a blinded blob of data using the + // ephemeral key provided. + DecryptBlindedHopData(ephemPub *btcec.PublicKey, + encryptedData []byte) ([]byte, error) + + // NextEphemeral returns the next hop's ephemeral key, calculated + // from the current ephemeral key provided. + NextEphemeral(*btcec.PublicKey) (*btcec.PublicKey, error) +} + +// BlindingKit contains the components required to extract forwarding +// information for hops in a blinded route. +type BlindingKit struct { + // Processor provides the low-level cryptographic operations to + // handle an encrypted blob of data in a blinded forward. + Processor BlindingProcessor + + // UpdateAddBlinding holds a blinding point that was passed to the + // node via update_add_htlc's TLVs. + UpdateAddBlinding lnwire.BlindingPointRecord + + // IncomingCltv is the expiry of the incoming HTLC. + IncomingCltv uint32 + + // IncomingAmount is the amount of the incoming HTLC. + IncomingAmount lnwire.MilliSatoshi +} + +// validateBlindingPoint validates that only one blinding point is present for +// the hop and returns the relevant one. +func (b *BlindingKit) validateBlindingPoint(payloadBlinding *btcec.PublicKey, + isFinalHop bool) (*btcec.PublicKey, error) { + + // Bolt 04: if encrypted_recipient_data is present: + // - if blinding_point (in update add) is set: + // - MUST error if current_blinding_point is set (in payload) + // - otherwise: + // - MUST return an error if current_blinding_point is not present + // (in payload) + payloadBlindingSet := payloadBlinding != nil + updateBlindingSet := b.UpdateAddBlinding.IsSome() + + switch { + case !(payloadBlindingSet || updateBlindingSet): + return nil, ErrInvalidPayload{ + Type: record.BlindingPointOnionType, + Violation: OmittedViolation, + FinalHop: isFinalHop, + } + + case payloadBlindingSet && updateBlindingSet: + return nil, ErrInvalidPayload{ + Type: record.BlindingPointOnionType, + Violation: IncludedViolation, + FinalHop: isFinalHop, + } + + case payloadBlindingSet: + return payloadBlinding, nil + + case updateBlindingSet: + pk, err := b.UpdateAddBlinding.UnwrapOrErr( + fmt.Errorf("expected update add blinding"), + ) + if err != nil { + return nil, err + } + + return pk.Val, nil + } + + return nil, fmt.Errorf("expected blinded point set") +} + +// DecryptAndValidateFwdInfo performs all operations required to decrypt and +// validate a blinded route. +func (b *BlindingKit) DecryptAndValidateFwdInfo(payload *Payload, + isFinalHop bool, payloadParsed map[tlv.Type][]byte) ( + *ForwardingInfo, error) { + + // We expect this function to be called when we have encrypted data + // present, and a blinding key is set either in the payload or the + // update_add_htlc message. + blindingPoint, err := b.validateBlindingPoint( + payload.blindingPoint, isFinalHop, + ) + if err != nil { + return nil, err + } + + decrypted, err := b.Processor.DecryptBlindedHopData( + blindingPoint, payload.encryptedData, + ) + if err != nil { + return nil, fmt.Errorf("decrypt blinded "+ + "data: %w", err) + } + + buf := bytes.NewBuffer(decrypted) + routeData, err := record.DecodeBlindedRouteData(buf) + if err != nil { + return nil, fmt.Errorf("%w: %w", + ErrDecodeFailed, err) + } + + // Validate the contents of the payload against the values we've + // just pulled out of the encrypted data blob. + err = ValidatePayloadWithBlinded(isFinalHop, payloadParsed) + if err != nil { + return nil, err + } + // Validate the data in the blinded route against our incoming htlc's + // information. + if err := ValidateBlindedRouteData( + routeData, b.IncomingAmount, b.IncomingCltv, + ); err != nil { + return nil, err + } + + fwdAmt, err := calculateForwardingAmount( + b.IncomingAmount, routeData.RelayInfo.Val.BaseFee, + routeData.RelayInfo.Val.FeeRate, + ) + if err != nil { + return nil, err + } + + // If we have an override for the blinding point for the next node, + // we'll just use it without tweaking (the sender intended to switch + // out directly for this blinding point). Otherwise, we'll tweak our + // blinding point to get the next ephemeral key. + nextEph, err := routeData.NextBlindingOverride.UnwrapOrFuncErr( + func() (tlv.RecordT[tlv.TlvType8, + *btcec.PublicKey], error) { + + next, err := b.Processor.NextEphemeral(blindingPoint) + if err != nil { + // Return a zero record because we expect the + // error to be checked. + return routeData.NextBlindingOverride.Zero(), + err + } + + return tlv.NewPrimitiveRecord[tlv.TlvType8](next), nil + }, + ) + if err != nil { + return nil, err + } + + return &ForwardingInfo{ + NextHop: routeData.ShortChannelID.Val, + AmountToForward: fwdAmt, + OutgoingCTLV: b.IncomingCltv - uint32( + routeData.RelayInfo.Val.CltvExpiryDelta, + ), + // Remap from blinding override type to blinding point type. + NextBlinding: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + nextEph.Val), + ), + }, nil +} + +// calculateForwardingAmount calculates the amount to forward for a blinded +// hop based on the incoming amount and forwarding parameters. +// +// When forwarding a payment, the fee we take is calculated, not on the +// incoming amount, but rather on the amount we forward. We charge fees based +// on our own liquidity we are forwarding downstream. +// +// With route blinding, we are NOT given the amount to forward. This +// unintuitive looking formula comes from the fact that without the amount to +// forward, we cannot compute the fees taken directly. +// +// The amount to be forwarded can be computed as follows: +// +// amt_to_forward = incoming_amount - total_fees +// total_fees = base_fee + amt_to_forward*(fee_rate/1000000) +// +// Solving for amount_to_forward: +// amt_to_forward = incoming_amount - base_fee - (amount_to_forward * fee_rate)/1e6 +// amt_to_forward + (amount_to_forward * fee_rate) / 1e6 = incoming_amount - base_fee +// amt_to_forward * 1e6 + (amount_to_forward * fee_rate) = (incoming_amount - base_fee) * 1e6 +// amt_to_forward * (1e6 + fee_rate) = (incoming_amount - base_fee) * 1e6 +// amt_to_forward = ((incoming_amount - base_fee) * 1e6) / (1e6 + fee_rate) +// +// From there we use a ceiling formula for integer division so that we always +// round up, otherwise the sender may receive slightly less than intended: +// +// ceil(a/b) = (a + b - 1)/(b). +// +//nolint:lll,dupword +func calculateForwardingAmount(incomingAmount lnwire.MilliSatoshi, baseFee, + proportionalFee uint32) (lnwire.MilliSatoshi, error) { + + // Sanity check to prevent overflow. + if incomingAmount < lnwire.MilliSatoshi(baseFee) { + return 0, fmt.Errorf("incoming amount: %v < base fee: %v", + incomingAmount, baseFee) + } + numerator := (uint64(incomingAmount) - uint64(baseFee)) * 1e6 + denominator := 1e6 + uint64(proportionalFee) + + ceiling := (numerator + denominator - 1) / denominator + + return lnwire.MilliSatoshi(ceiling), nil +} + // OnionProcessor is responsible for keeping all sphinx dependent parts inside // and expose only decoding function. With such approach we give freedom for // subsystems which wants to decode sphinx path to not be dependable from @@ -147,11 +397,24 @@ func (p *OnionProcessor) Stop() error { return nil } -// ReconstructHopIterator attempts to decode a valid sphinx packet from the passed io.Reader -// instance using the rHash as the associated data when checking the relevant -// MACs during the decoding process. +// ReconstructBlindingInfo contains the information required to reconstruct a +// blinded onion. +type ReconstructBlindingInfo struct { + // BlindingKey is the blinding point set in UpdateAddHTLC. + BlindingKey lnwire.BlindingPointRecord + + // IncomingAmt is the amount for the incoming HTLC. + IncomingAmt lnwire.MilliSatoshi + + // IncomingExpiry is the expiry height of the incoming HTLC. + IncomingExpiry uint32 +} + +// ReconstructHopIterator attempts to decode a valid sphinx packet from the +// passed io.Reader instance using the rHash as the associated data when +// checking the relevant MACs during the decoding process. func (p *OnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte, - blindingPoint *btcec.PublicKey) (Iterator, error) { + blindingInfo ReconstructBlindingInfo) (Iterator, error) { onionPkt := &sphinx.OnionPacket{} if err := onionPkt.Decode(r); err != nil { @@ -159,9 +422,11 @@ func (p *OnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte, } var opts []sphinx.ProcessOnionOpt - if blindingPoint != nil { - opts = append(opts, sphinx.WithBlindingPoint(blindingPoint)) - } + blindingInfo.BlindingKey.WhenSome(func( + r tlv.RecordT[lnwire.BlindingPointTlvType, *btcec.PublicKey]) { + + opts = append(opts, sphinx.WithBlindingPoint(r.Val)) + }) // Attempt to process the Sphinx packet. We include the payment hash of // the HTLC as it's authenticated within the Sphinx packet itself as @@ -175,7 +440,12 @@ func (p *OnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte, return nil, err } - return makeSphinxHopIterator(onionPkt, sphinxPacket), nil + return makeSphinxHopIterator(onionPkt, sphinxPacket, BlindingKit{ + Processor: p.router, + UpdateAddBlinding: blindingInfo.BlindingKey, + IncomingAmount: blindingInfo.IncomingAmt, + IncomingCltv: blindingInfo.IncomingExpiry, + }), nil } // DecodeHopIteratorRequest encapsulates all date necessary to process an onion @@ -186,7 +456,7 @@ type DecodeHopIteratorRequest struct { RHash []byte IncomingCltv uint32 IncomingAmount lnwire.MilliSatoshi - BlindingPoint *btcec.PublicKey + BlindingPoint lnwire.BlindingPointRecord } // DecodeHopIteratorResponse encapsulates the outcome of a batched sphinx onion @@ -243,12 +513,14 @@ func (p *OnionProcessor) DecodeHopIterators(id []byte, } var opts []sphinx.ProcessOnionOpt - if req.BlindingPoint != nil { + req.BlindingPoint.WhenSome(func( + b tlv.RecordT[lnwire.BlindingPointTlvType, + *btcec.PublicKey]) { + opts = append(opts, sphinx.WithBlindingPoint( - req.BlindingPoint, + b.Val, )) - } - + }) err = tx.ProcessOnionPacket( seqNum, onionPkt, req.RHash, req.IncomingCltv, opts..., ) @@ -350,7 +622,14 @@ func (p *OnionProcessor) DecodeHopIterators(id []byte, // Finally, construct a hop iterator from our processed sphinx // packet, simultaneously caching the original onion packet. - resp.HopIterator = makeSphinxHopIterator(&onionPkts[i], &packets[i]) + resp.HopIterator = makeSphinxHopIterator( + &onionPkts[i], &packets[i], BlindingKit{ + Processor: p.router, + UpdateAddBlinding: reqs[i].BlindingPoint, + IncomingAmount: reqs[i].IncomingAmount, + IncomingCltv: reqs[i].IncomingCltv, + }, + ) } return resps, nil diff --git a/htlcswitch/hop/iterator_test.go b/htlcswitch/hop/iterator_test.go index cb2a2816f9e..60919333b37 100644 --- a/htlcswitch/hop/iterator_test.go +++ b/htlcswitch/hop/iterator_test.go @@ -3,8 +3,10 @@ package hop import ( "bytes" "encoding/binary" + "errors" "testing" + "github.com/btcsuite/btcd/btcec/v2" "github.com/davecgh/go-spew/spew" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/lnwire" @@ -100,3 +102,201 @@ func TestSphinxHopIteratorForwardingInstructions(t *testing.T) { } } } + +// TestForwardingAmountCalc tests calculation of forwarding amounts from the +// hop's forwarding parameters. +func TestForwardingAmountCalc(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + incomingAmount lnwire.MilliSatoshi + baseFee uint32 + proportional uint32 + forwardAmount lnwire.MilliSatoshi + expectErr bool + }{ + { + name: "overflow", + incomingAmount: 10, + baseFee: 100, + expectErr: true, + }, + { + name: "trivial proportional", + incomingAmount: 100_000, + baseFee: 1000, + proportional: 10, + forwardAmount: 99000, + }, + { + name: "both fees charged", + incomingAmount: 10_002_020, + baseFee: 1000, + proportional: 1, + forwardAmount: 10_001_010, + }, + } + + for _, testCase := range tests { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + actual, err := calculateForwardingAmount( + testCase.incomingAmount, testCase.baseFee, + testCase.proportional, + ) + + require.Equal(t, testCase.expectErr, err != nil) + require.Equal(t, testCase.forwardAmount.ToSatoshis(), + actual.ToSatoshis()) + }) + } +} + +// mockProcessor is a mocked blinding point processor that just returns the +// data that it is called with when "decrypting". +type mockProcessor struct { + decryptErr error +} + +// DecryptBlindedHopData mocks blob decryption, returning the same data that +// it was called with and an optionally configured error. +func (m *mockProcessor) DecryptBlindedHopData(_ *btcec.PublicKey, + data []byte) ([]byte, error) { + + return data, m.decryptErr +} + +// NextEphemeral mocks getting our next ephemeral key. +func (m *mockProcessor) NextEphemeral(*btcec.PublicKey) (*btcec.PublicKey, + error) { + + return nil, nil +} + +// TestDecryptAndValidateFwdInfo tests deriving forwarding info using a +// blinding kit. This test does not cover assertions on the calculations of +// forwarding information, because this is covered in a test dedicated to those +// calculations. +func TestDecryptAndValidateFwdInfo(t *testing.T) { + t.Parallel() + + // Encode valid blinding data that we'll fake decrypting for our test. + maxCltv := 1000 + blindedData := record.NewBlindedRouteData( + lnwire.NewShortChanIDFromInt(1500), nil, + record.PaymentRelayInfo{ + CltvExpiryDelta: 10, + BaseFee: 100, + FeeRate: 0, + }, + &record.PaymentConstraints{ + MaxCltvExpiry: 1000, + HtlcMinimumMsat: lnwire.MilliSatoshi(1), + }, + nil, + ) + + validData, err := record.EncodeBlindedRouteData(blindedData) + require.NoError(t, err) + + // Mocked error. + errDecryptFailed := errors.New("could not decrypt") + + tests := []struct { + name string + data []byte + incomingCLTV uint32 + updateAddBlinding *btcec.PublicKey + payloadBlinding *btcec.PublicKey + processor *mockProcessor + expectedErr error + }{ + { + name: "no blinding point", + data: validData, + processor: &mockProcessor{}, + expectedErr: ErrInvalidPayload{ + Type: record.BlindingPointOnionType, + Violation: OmittedViolation, + }, + }, + { + name: "both blinding points", + data: validData, + updateAddBlinding: &btcec.PublicKey{}, + payloadBlinding: &btcec.PublicKey{}, + processor: &mockProcessor{}, + expectedErr: ErrInvalidPayload{ + Type: record.BlindingPointOnionType, + Violation: IncludedViolation, + }, + }, + { + name: "decryption failed", + data: validData, + updateAddBlinding: &btcec.PublicKey{}, + incomingCLTV: 500, + processor: &mockProcessor{ + decryptErr: errDecryptFailed, + }, + expectedErr: errDecryptFailed, + }, + { + name: "decode fails", + data: []byte{1, 2, 3}, + updateAddBlinding: &btcec.PublicKey{}, + incomingCLTV: 500, + processor: &mockProcessor{}, + expectedErr: ErrDecodeFailed, + }, + { + name: "validation fails", + data: validData, + updateAddBlinding: &btcec.PublicKey{}, + incomingCLTV: uint32(maxCltv) + 10, + processor: &mockProcessor{}, + expectedErr: ErrInvalidPayload{ + Type: record.LockTimeOnionType, + Violation: InsufficientViolation, + }, + }, + { + name: "valid", + updateAddBlinding: &btcec.PublicKey{}, + data: validData, + processor: &mockProcessor{}, + expectedErr: nil, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + // We don't actually use blinding keys due to our + // mocking so they can be nil. + kit := BlindingKit{ + Processor: testCase.processor, + IncomingAmount: 10000, + IncomingCltv: testCase.incomingCLTV, + } + + if testCase.updateAddBlinding != nil { + kit.UpdateAddBlinding = tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](testCase.updateAddBlinding), + ) + } + _, err := kit.DecryptAndValidateFwdInfo( + &Payload{ + encryptedData: testCase.data, + blindingPoint: testCase.payloadBlinding, + }, false, + make(map[tlv.Type][]byte), + ) + require.ErrorIs(t, err, testCase.expectedErr) + }) + } +} diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index cbd8d08a578..70fdb1403fc 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -133,11 +133,14 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload { } } -// NewPayloadFromReader builds a new Hop from the passed io.Reader. The reader +// NewPayloadFromReader builds a new Hop from the passed io.Reader and returns +// a map of all the types that were found in the payload. The reader // should correspond to the bytes encapsulated in a TLV onion payload. The // final hop bool signals that this payload was the final packet parsed by // sphinx. -func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, error) { +func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, + map[tlv.Type][]byte, error) { + var ( cid uint64 amt uint64 @@ -162,27 +165,27 @@ func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, error) { record.NewTotalAmtMsatBlinded(&totalAmtMsat), ) if err != nil { - return nil, err + return nil, nil, err } // Since this data is provided by a potentially malicious peer, pass it // into the P2P decoding variant. parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r) if err != nil { - return nil, err + return nil, nil, err } // Validate whether the sender properly included or omitted tlv records // in accordance with BOLT 04. err = ValidateParsedPayloadTypes(parsedTypes, finalHop) if err != nil { - return nil, err + return nil, nil, err } // Check for violation of the rules for mandatory fields. violatingType := getMinRequiredViolation(parsedTypes) if violatingType != nil { - return nil, ErrInvalidPayload{ + return nil, nil, ErrInvalidPayload{ Type: *violatingType, Violation: RequiredViolation, FinalHop: finalHop, @@ -229,7 +232,7 @@ func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, error) { blindingPoint: blindingPoint, customRecords: customRecords, totalAmtMsat: lnwire.MilliSatoshi(totalAmtMsat), - }, nil + }, nil, nil } // ForwardingInfo returns the basic parameters required for HTLC forwarding, @@ -484,3 +487,37 @@ func ValidateBlindedRouteData(blindedData *record.BlindedRouteData, return nil } + +// ValidatePayloadWithBlinded validates a payload against the contents of +// its encrypted data blob. +func ValidatePayloadWithBlinded(isFinalHop bool, + payloadParsed map[tlv.Type][]byte) error { + + // Blinded routes restrict the presence of TLVs more strictly than + // regular routes, check that intermediate and final hops only have + // the TLVs the spec allows them to have. + allowedTLVs := map[tlv.Type]bool{ + record.EncryptedDataOnionType: true, + record.BlindingPointOnionType: true, + } + + if isFinalHop { + allowedTLVs[record.AmtOnionType] = true + allowedTLVs[record.LockTimeOnionType] = true + allowedTLVs[record.TotalAmtMsatBlindedType] = true + } + + for tlvType := range payloadParsed { + if _, ok := allowedTLVs[tlvType]; ok { + continue + } + + return ErrInvalidPayload{ + Type: tlvType, + Violation: IncludedViolation, + FinalHop: isFinalHop, + } + } + + return nil +} diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go index 148b806f96f..301e5771660 100644 --- a/htlcswitch/hop/payload_test.go +++ b/htlcswitch/hop/payload_test.go @@ -10,6 +10,7 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" ) @@ -478,7 +479,7 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { testChildIndex = uint32(9) ) - p, err := hop.NewPayloadFromReader( + p, _, err := hop.NewPayloadFromReader( bytes.NewReader(test.payload), test.isFinalHop, ) if !reflect.DeepEqual(test.expErr, err) { @@ -695,3 +696,67 @@ func TestValidateBlindedRouteData(t *testing.T) { }) } } + +// TestValidatePayloadWithBlinded tests validation of the contents of a +// payload when it's for a blinded payment. +func TestValidatePayloadWithBlinded(t *testing.T) { + t.Parallel() + + finalHopMap := map[tlv.Type][]byte{ + record.AmtOnionType: nil, + record.LockTimeOnionType: nil, + record.TotalAmtMsatBlindedType: nil, + } + + tests := []struct { + name string + isFinal bool + parsed map[tlv.Type][]byte + err bool + }{ + { + name: "final hop, valid", + isFinal: true, + parsed: finalHopMap, + }, + { + name: "intermediate hop, invalid", + isFinal: false, + parsed: finalHopMap, + err: true, + }, + { + name: "intermediate hop, invalid", + isFinal: false, + parsed: map[tlv.Type][]byte{ + record.EncryptedDataOnionType: nil, + record.BlindingPointOnionType: nil, + }, + }, + { + name: "unknown record, invalid", + isFinal: false, + parsed: map[tlv.Type][]byte{ + tlv.Type(99): nil, + }, + err: true, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + err := hop.ValidatePayloadWithBlinded( + testCase.isFinal, testCase.parsed, + ) + + // We can't determine our exact error because we + // iterate through a map (non-deterministic) in the + // function. + if testCase.err { + require.NotNil(t, err) + } else { + require.Nil(t, err) + } + }) + } +} diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 7e1dded1dda..778e78d7008 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -273,6 +273,11 @@ type ChannelLinkConfig struct { // re-establish and should not allow anymore HTLC adds on the outgoing // direction of the link. PreviouslySentShutdown fn.Option[lnwire.Shutdown] + + // Adds the option to disable forwarding payments in blinded routes + // by failing back any blinding-related payloads as if they were + // invalid. + DisallowRouteBlinding bool } // channelLink is the service which drives a channel's commitment update @@ -1928,6 +1933,19 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { return } + // Disallow htlcs with blinding points set if we haven't + // enabled the feature. This saves us from having to process + // the onion at all, but will only catch blinded payments + // where we are a relaying node (as the blinding point will + // be in the payload when we're the introduction node). + if msg.BlindingPoint.IsSome() && l.cfg.DisallowRouteBlinding { + l.fail(LinkFailureError{code: ErrInvalidUpdate}, + "blinding point included when route blinding "+ + "is disabled") + + return + } + // We just received an add request from an upstream peer, so we // add it to our state machine, then add the HTLC to our // "settle" list in the event that we know the preimage. @@ -3291,6 +3309,27 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, fwdInfo := pld.ForwardingInfo() + // Check whether the payload we've just processed uses our + // node as the introduction point (gave us a blinding key in + // the payload itself) and fail it back if we don't support + // route blinding. + if fwdInfo.NextBlinding.IsSome() && + l.cfg.DisallowRouteBlinding { + + failure := lnwire.NewInvalidBlinding( + onionBlob[:], + ) + l.sendHTLCError( + pd, NewLinkError(failure), obfuscator, false, + ) + + l.log.Error("rejected htlc that uses use as an " + + "introduction point when we do not support " + + "route blinding") + + continue + } + switch fwdInfo.NextHop { case hop.Exit: err := l.processExitHop( @@ -3330,9 +3369,10 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, // Otherwise, it was already processed, we can // can collect it and continue. addMsg := &lnwire.UpdateAddHTLC{ - Expiry: fwdInfo.OutgoingCTLV, - Amount: fwdInfo.AmountToForward, - PaymentHash: pd.RHash, + Expiry: fwdInfo.OutgoingCTLV, + Amount: fwdInfo.AmountToForward, + PaymentHash: pd.RHash, + BlindingPoint: fwdInfo.NextBlinding, } // Finally, we'll encode the onion packet for @@ -3375,9 +3415,10 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, // create the outgoing HTLC using the parameters as // specified in the forwarding info. addMsg := &lnwire.UpdateAddHTLC{ - Expiry: fwdInfo.OutgoingCTLV, - Amount: fwdInfo.AmountToForward, - PaymentHash: pd.RHash, + Expiry: fwdInfo.OutgoingCTLV, + Amount: fwdInfo.AmountToForward, + PaymentHash: pd.RHash, + BlindingPoint: fwdInfo.NextBlinding, } // Finally, we'll encode the onion packet for the diff --git a/itest/list_on_test.go b/itest/list_on_test.go index 7a618402f8f..f78601a1040 100644 --- a/itest/list_on_test.go +++ b/itest/list_on_test.go @@ -558,6 +558,10 @@ var allTestCases = []*lntest.TestCase{ Name: "query blinded route", TestFunc: testQueryBlindedRoutes, }, + { + Name: "forward blinded", + TestFunc: testForwardBlindedRoute, + }, { Name: "removetx", TestFunc: testRemoveTx, diff --git a/itest/lnd_route_blinding.go b/itest/lnd_route_blinding.go deleted file mode 100644 index 2104cb1c830..00000000000 --- a/itest/lnd_route_blinding.go +++ /dev/null @@ -1,312 +0,0 @@ -package itest - -import ( - "crypto/sha256" - "encoding/hex" - - "github.com/btcsuite/btcd/btcec/v2" - "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/chainreg" - "github.com/lightningnetwork/lnd/lnrpc" - "github.com/lightningnetwork/lnd/lnrpc/routerrpc" - "github.com/lightningnetwork/lnd/lntest" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// testQueryBlindedRoutes tests querying routes to blinded routes. To do this, -// it sets up a nework of Alice - Bob - Carol and creates a mock blinded route -// that uses Carol as the introduction node (plus dummy hops to cover multiple -// hops). The test simply asserts that the structure of the route is as -// expected. It also includes the edge case of a single-hop blinded route, -// which indicates that the introduction node is the recipient. -func testQueryBlindedRoutes(ht *lntest.HarnessTest) { - var ( - // Convenience aliases. - alice = ht.Alice - bob = ht.Bob - ) - - // Setup a two hop channel network: Alice -- Bob -- Carol. - // We set our proportional fee for these channels to zero, so that - // our calculations are easier. This is okay, because we're not testing - // the basic mechanics of pathfinding in this test. - chanAmt := btcutil.Amount(100000) - chanPointAliceBob := ht.OpenChannel( - alice, bob, lntest.OpenChannelParams{ - Amt: chanAmt, - BaseFee: 10000, - FeeRate: 0, - UseBaseFee: true, - UseFeeRate: true, - }, - ) - - carol := ht.NewNode("Carol", nil) - ht.EnsureConnected(bob, carol) - - var bobCarolBase uint64 = 2000 - chanPointBobCarol := ht.OpenChannel( - bob, carol, lntest.OpenChannelParams{ - Amt: chanAmt, - BaseFee: bobCarolBase, - FeeRate: 0, - UseBaseFee: true, - UseFeeRate: true, - }, - ) - - // Wait for Alice to see Bob/Carol's channel because she'll need it for - // pathfinding. - ht.AssertTopologyChannelOpen(alice, chanPointBobCarol) - - // Lookup full channel info so that we have channel ids for our route. - aliceBobChan := ht.GetChannelByChanPoint(alice, chanPointAliceBob) - bobCarolChan := ht.GetChannelByChanPoint(bob, chanPointBobCarol) - - // Sanity check that bob's fee is as expected. - chanInfoReq := &lnrpc.ChanInfoRequest{ - ChanId: bobCarolChan.ChanId, - } - - bobCarolInfo := bob.RPC.GetChanInfo(chanInfoReq) - - // Our test relies on knowing the fee rate for bob - carol to set the - // fees we expect for our route. Perform a quick sanity check that our - // policy is as expected. - var policy *lnrpc.RoutingPolicy - if bobCarolInfo.Node1Pub == bob.PubKeyStr { - policy = bobCarolInfo.Node1Policy - } else { - policy = bobCarolInfo.Node2Policy - } - require.Equal(ht, bobCarolBase, uint64(policy.FeeBaseMsat), "base fee") - require.EqualValues(ht, 0, policy.FeeRateMilliMsat, "fee rate") - - // We'll also need the current block height to calculate our locktimes. - info := alice.RPC.GetInfo() - - // Since we created channels with default parameters, we can assume - // that all of our channels have the default cltv delta. - bobCarolDelta := uint32(chainreg.DefaultBitcoinTimeLockDelta) - - // Create arbitrary pubkeys for use in our blinded route. They're not - // actually used functionally in this test, so we can just make them up. - var ( - _, blindingPoint = btcec.PrivKeyFromBytes([]byte{1}) - _, carolBlinded = btcec.PrivKeyFromBytes([]byte{2}) - _, blindedHop1 = btcec.PrivKeyFromBytes([]byte{3}) - _, blindedHop2 = btcec.PrivKeyFromBytes([]byte{4}) - - encryptedDataCarol = []byte{1, 2, 3} - encryptedData1 = []byte{4, 5, 6} - encryptedData2 = []byte{7, 8, 9} - - blindingBytes = blindingPoint.SerializeCompressed() - carolBlindedBytes = carolBlinded.SerializeCompressed() - blinded1Bytes = blindedHop1.SerializeCompressed() - blinded2Bytes = blindedHop2.SerializeCompressed() - ) - - // Now we create a blinded route which uses carol as an introduction - // node followed by two dummy hops (the arbitrary pubkeys in our - // blinded route above: - // Carol --- B1 --- B2 - route := &lnrpc.BlindedPath{ - IntroductionNode: carol.PubKey[:], - BlindingPoint: blindingBytes, - BlindedHops: []*lnrpc.BlindedHop{ - { - // The first hop in the blinded route is - // expected to be the introduction node. - BlindedNode: carolBlindedBytes, - EncryptedData: encryptedDataCarol, - }, - { - BlindedNode: blinded1Bytes, - EncryptedData: encryptedData1, - }, - { - BlindedNode: blinded2Bytes, - EncryptedData: encryptedData2, - }, - }, - } - - // Create a blinded payment that has aggregate cltv and fee params - // for our route. - var ( - blindedBaseFee uint64 = 1500 - blindedCltvDelta uint32 = 125 - ) - - blindedPayment := &lnrpc.BlindedPaymentPath{ - BlindedPath: route, - BaseFeeMsat: blindedBaseFee, - TotalCltvDelta: blindedCltvDelta, - } - - // Query for a route to the blinded path constructed above. - var paymentAmt int64 = 100_000 - - req := &lnrpc.QueryRoutesRequest{ - AmtMsat: paymentAmt, - BlindedPaymentPaths: []*lnrpc.BlindedPaymentPath{ - blindedPayment, - }, - } - - resp := alice.RPC.QueryRoutes(req) - require.Len(ht, resp.Routes, 1) - - // Payment amount and cltv will be included for the bob/carol edge - // (because we apply on the outgoing hop), and the blinded portion of - // the route. - totalFee := bobCarolBase + blindedBaseFee - totalAmt := uint64(paymentAmt) + totalFee - totalCltv := info.BlockHeight + bobCarolDelta + blindedCltvDelta - - // Alice -> Bob - // Forward: total - bob carol fees - // Expiry: total - bob carol delta - // - // Bob -> Carol - // Forward: 101500 (total + blinded fees) - // Expiry: Height + blinded cltv delta - // Encrypted Data: enc_carol - // - // Carol -> Blinded 1 - // Forward/ Expiry: 0 - // Encrypted Data: enc_1 - // - // Blinded 1 -> Blinded 2 - // Forward/ Expiry: Height - // Encrypted Data: enc_2 - hop0Amount := int64(totalAmt - bobCarolBase) - hop0Expiry := totalCltv - bobCarolDelta - finalHopExpiry := totalCltv - bobCarolDelta - blindedCltvDelta - - expectedRoute := &lnrpc.Route{ - TotalTimeLock: totalCltv, - TotalAmtMsat: int64(totalAmt), - TotalFeesMsat: int64(totalFee), - Hops: []*lnrpc.Hop{ - { - ChanId: aliceBobChan.ChanId, - Expiry: hop0Expiry, - AmtToForwardMsat: hop0Amount, - FeeMsat: int64(bobCarolBase), - PubKey: bob.PubKeyStr, - }, - { - ChanId: bobCarolChan.ChanId, - PubKey: carol.PubKeyStr, - BlindingPoint: blindingBytes, - FeeMsat: int64(blindedBaseFee), - EncryptedData: encryptedDataCarol, - }, - { - PubKey: hex.EncodeToString( - blinded1Bytes, - ), - EncryptedData: encryptedData1, - }, - { - PubKey: hex.EncodeToString( - blinded2Bytes, - ), - AmtToForwardMsat: paymentAmt, - Expiry: finalHopExpiry, - EncryptedData: encryptedData2, - TotalAmtMsat: uint64(paymentAmt), - }, - }, - } - - r := resp.Routes[0] - assert.Equal(ht, expectedRoute.TotalTimeLock, r.TotalTimeLock) - assert.Equal(ht, expectedRoute.TotalAmtMsat, r.TotalAmtMsat) - assert.Equal(ht, expectedRoute.TotalFeesMsat, r.TotalFeesMsat) - - assert.Equal(ht, len(expectedRoute.Hops), len(r.Hops)) - for i, hop := range expectedRoute.Hops { - assert.Equal(ht, hop.PubKey, r.Hops[i].PubKey, - "hop: %v pubkey", i) - - assert.Equal(ht, hop.ChanId, r.Hops[i].ChanId, - "hop: %v chan id", i) - - assert.Equal(ht, hop.Expiry, r.Hops[i].Expiry, - "hop: %v expiry", i) - - assert.Equal(ht, hop.AmtToForwardMsat, - r.Hops[i].AmtToForwardMsat, "hop: %v forward", i) - - assert.Equal(ht, hop.FeeMsat, r.Hops[i].FeeMsat, - "hop: %v fee", i) - - assert.Equal(ht, hop.BlindingPoint, r.Hops[i].BlindingPoint, - "hop: %v blinding point", i) - - assert.Equal(ht, hop.EncryptedData, r.Hops[i].EncryptedData, - "hop: %v encrypted data", i) - } - - // Dispatch a payment to our blinded route. - preimage := [33]byte{1, 2, 3} - hash := sha256.Sum256(preimage[:]) - - sendReq := &routerrpc.SendToRouteRequest{ - PaymentHash: hash[:], - Route: r, - } - - htlcAttempt := alice.RPC.SendToRouteV2(sendReq) - - // Since Carol doesn't understand blinded routes, we expect her to fail - // the payment because the onion payload is invalid (missing amount to - // forward). - require.NotNil(ht, htlcAttempt.Failure) - require.Equal(ht, uint32(2), htlcAttempt.Failure.FailureSourceIndex) - - // Next, we test an edge case where just an introduction node is - // included as a "single hop blinded route". - sendToIntroCLTVFinal := uint32(15) - sendToIntroTimelock := info.BlockHeight + bobCarolDelta + - sendToIntroCLTVFinal - - introNodeBlinded := &lnrpc.BlindedPaymentPath{ - BlindedPath: &lnrpc.BlindedPath{ - IntroductionNode: carol.PubKey[:], - BlindingPoint: blindingBytes, - BlindedHops: []*lnrpc.BlindedHop{ - { - // The first hop in the blinded route is - // expected to be the introduction node. - BlindedNode: carolBlindedBytes, - EncryptedData: encryptedDataCarol, - }, - }, - }, - // Fees should be zero for a single hop blinded path, and the - // total cltv expiry is just expected to cover the final cltv - // delta of the receiving node (ie, the introduction node). - BaseFeeMsat: 0, - TotalCltvDelta: sendToIntroCLTVFinal, - } - req = &lnrpc.QueryRoutesRequest{ - AmtMsat: paymentAmt, - BlindedPaymentPaths: []*lnrpc.BlindedPaymentPath{ - introNodeBlinded, - }, - } - - // Assert that we have one route, and two hops: Alice/Bob and Bob/Carol. - resp = alice.RPC.QueryRoutes(req) - require.Len(ht, resp.Routes, 1) - require.Len(ht, resp.Routes[0].Hops, 2) - require.Equal(ht, resp.Routes[0].TotalTimeLock, sendToIntroTimelock) - - ht.CloseChannel(alice, chanPointAliceBob) - ht.CloseChannel(bob, chanPointBobCarol) -} diff --git a/itest/lnd_route_blinding_test.go b/itest/lnd_route_blinding_test.go new file mode 100644 index 00000000000..430b001c6d9 --- /dev/null +++ b/itest/lnd_route_blinding_test.go @@ -0,0 +1,790 @@ +package itest + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/chainreg" + "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" + "github.com/lightningnetwork/lnd/lntest" + "github.com/lightningnetwork/lnd/lntest/node" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/routing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testQueryBlindedRoutes tests querying routes to blinded routes. To do this, +// it sets up a nework of Alice - Bob - Carol and creates a mock blinded route +// that uses Carol as the introduction node (plus dummy hops to cover multiple +// hops). The test simply asserts that the structure of the route is as +// expected. It also includes the edge case of a single-hop blinded route, +// which indicates that the introduction node is the recipient. +func testQueryBlindedRoutes(ht *lntest.HarnessTest) { + var ( + // Convenience aliases. + alice = ht.Alice + bob = ht.Bob + ) + + // Setup a two hop channel network: Alice -- Bob -- Carol. + // We set our proportional fee for these channels to zero, so that + // our calculations are easier. This is okay, because we're not testing + // the basic mechanics of pathfinding in this test. + chanAmt := btcutil.Amount(100000) + chanPointAliceBob := ht.OpenChannel( + alice, bob, lntest.OpenChannelParams{ + Amt: chanAmt, + BaseFee: 10000, + FeeRate: 0, + UseBaseFee: true, + UseFeeRate: true, + }, + ) + + carol := ht.NewNode("Carol", nil) + ht.EnsureConnected(bob, carol) + + var bobCarolBase uint64 = 2000 + chanPointBobCarol := ht.OpenChannel( + bob, carol, lntest.OpenChannelParams{ + Amt: chanAmt, + BaseFee: bobCarolBase, + FeeRate: 0, + UseBaseFee: true, + UseFeeRate: true, + }, + ) + + // Wait for Alice to see Bob/Carol's channel because she'll need it for + // pathfinding. + ht.AssertTopologyChannelOpen(alice, chanPointBobCarol) + + // Lookup full channel info so that we have channel ids for our route. + aliceBobChan := ht.GetChannelByChanPoint(alice, chanPointAliceBob) + bobCarolChan := ht.GetChannelByChanPoint(bob, chanPointBobCarol) + + // Sanity check that bob's fee is as expected. + chanInfoReq := &lnrpc.ChanInfoRequest{ + ChanId: bobCarolChan.ChanId, + } + + bobCarolInfo := bob.RPC.GetChanInfo(chanInfoReq) + + // Our test relies on knowing the fee rate for bob - carol to set the + // fees we expect for our route. Perform a quick sanity check that our + // policy is as expected. + var policy *lnrpc.RoutingPolicy + if bobCarolInfo.Node1Pub == bob.PubKeyStr { + policy = bobCarolInfo.Node1Policy + } else { + policy = bobCarolInfo.Node2Policy + } + require.Equal(ht, bobCarolBase, uint64(policy.FeeBaseMsat), "base fee") + require.EqualValues(ht, 0, policy.FeeRateMilliMsat, "fee rate") + + // We'll also need the current block height to calculate our locktimes. + info := alice.RPC.GetInfo() + + // Since we created channels with default parameters, we can assume + // that all of our channels have the default cltv delta. + bobCarolDelta := uint32(chainreg.DefaultBitcoinTimeLockDelta) + + // Create arbitrary pubkeys for use in our blinded route. They're not + // actually used functionally in this test, so we can just make them up. + var ( + _, blindingPoint = btcec.PrivKeyFromBytes([]byte{1}) + _, carolBlinded = btcec.PrivKeyFromBytes([]byte{2}) + _, blindedHop1 = btcec.PrivKeyFromBytes([]byte{3}) + _, blindedHop2 = btcec.PrivKeyFromBytes([]byte{4}) + + encryptedDataCarol = []byte{1, 2, 3} + encryptedData1 = []byte{4, 5, 6} + encryptedData2 = []byte{7, 8, 9} + + blindingBytes = blindingPoint.SerializeCompressed() + carolBlindedBytes = carolBlinded.SerializeCompressed() + blinded1Bytes = blindedHop1.SerializeCompressed() + blinded2Bytes = blindedHop2.SerializeCompressed() + ) + + // Now we create a blinded route which uses carol as an introduction + // node followed by two dummy hops (the arbitrary pubkeys in our + // blinded route above: + // Carol --- B1 --- B2 + route := &lnrpc.BlindedPath{ + IntroductionNode: carol.PubKey[:], + BlindingPoint: blindingBytes, + BlindedHops: []*lnrpc.BlindedHop{ + { + // The first hop in the blinded route is + // expected to be the introduction node. + BlindedNode: carolBlindedBytes, + EncryptedData: encryptedDataCarol, + }, + { + BlindedNode: blinded1Bytes, + EncryptedData: encryptedData1, + }, + { + BlindedNode: blinded2Bytes, + EncryptedData: encryptedData2, + }, + }, + } + + // Create a blinded payment that has aggregate cltv and fee params + // for our route. + var ( + blindedBaseFee uint64 = 1500 + blindedCltvDelta uint32 = 125 + ) + + blindedPayment := &lnrpc.BlindedPaymentPath{ + BlindedPath: route, + BaseFeeMsat: blindedBaseFee, + TotalCltvDelta: blindedCltvDelta, + } + + // Query for a route to the blinded path constructed above. + var paymentAmt int64 = 100_000 + + req := &lnrpc.QueryRoutesRequest{ + AmtMsat: paymentAmt, + BlindedPaymentPaths: []*lnrpc.BlindedPaymentPath{ + blindedPayment, + }, + } + + resp := alice.RPC.QueryRoutes(req) + require.Len(ht, resp.Routes, 1) + + // Payment amount and cltv will be included for the bob/carol edge + // (because we apply on the outgoing hop), and the blinded portion of + // the route. + totalFee := bobCarolBase + blindedBaseFee + totalAmt := uint64(paymentAmt) + totalFee + totalCltv := info.BlockHeight + bobCarolDelta + blindedCltvDelta + + // Alice -> Bob + // Forward: total - bob carol fees + // Expiry: total - bob carol delta + // + // Bob -> Carol + // Forward: 101500 (total + blinded fees) + // Expiry: Height + blinded cltv delta + // Encrypted Data: enc_carol + // + // Carol -> Blinded 1 + // Forward/ Expiry: 0 + // Encrypted Data: enc_1 + // + // Blinded 1 -> Blinded 2 + // Forward/ Expiry: Height + // Encrypted Data: enc_2 + hop0Amount := int64(totalAmt - bobCarolBase) + hop0Expiry := totalCltv - bobCarolDelta + finalHopExpiry := totalCltv - bobCarolDelta - blindedCltvDelta + + expectedRoute := &lnrpc.Route{ + TotalTimeLock: totalCltv, + TotalAmtMsat: int64(totalAmt), + TotalFeesMsat: int64(totalFee), + Hops: []*lnrpc.Hop{ + { + ChanId: aliceBobChan.ChanId, + Expiry: hop0Expiry, + AmtToForwardMsat: hop0Amount, + FeeMsat: int64(bobCarolBase), + PubKey: bob.PubKeyStr, + }, + { + ChanId: bobCarolChan.ChanId, + PubKey: carol.PubKeyStr, + BlindingPoint: blindingBytes, + FeeMsat: int64(blindedBaseFee), + EncryptedData: encryptedDataCarol, + }, + { + PubKey: hex.EncodeToString( + blinded1Bytes, + ), + EncryptedData: encryptedData1, + }, + { + PubKey: hex.EncodeToString( + blinded2Bytes, + ), + AmtToForwardMsat: paymentAmt, + Expiry: finalHopExpiry, + EncryptedData: encryptedData2, + TotalAmtMsat: uint64(paymentAmt), + }, + }, + } + + r := resp.Routes[0] + assert.Equal(ht, expectedRoute.TotalTimeLock, r.TotalTimeLock) + assert.Equal(ht, expectedRoute.TotalAmtMsat, r.TotalAmtMsat) + assert.Equal(ht, expectedRoute.TotalFeesMsat, r.TotalFeesMsat) + + assert.Equal(ht, len(expectedRoute.Hops), len(r.Hops)) + for i, hop := range expectedRoute.Hops { + assert.Equal(ht, hop.PubKey, r.Hops[i].PubKey, + "hop: %v pubkey", i) + + assert.Equal(ht, hop.ChanId, r.Hops[i].ChanId, + "hop: %v chan id", i) + + assert.Equal(ht, hop.Expiry, r.Hops[i].Expiry, + "hop: %v expiry", i) + + assert.Equal(ht, hop.AmtToForwardMsat, + r.Hops[i].AmtToForwardMsat, "hop: %v forward", i) + + assert.Equal(ht, hop.FeeMsat, r.Hops[i].FeeMsat, + "hop: %v fee", i) + + assert.Equal(ht, hop.BlindingPoint, r.Hops[i].BlindingPoint, + "hop: %v blinding point", i) + + assert.Equal(ht, hop.EncryptedData, r.Hops[i].EncryptedData, + "hop: %v encrypted data", i) + } + + // Dispatch a payment to our blinded route. + preimage := [33]byte{1, 2, 3} + hash := sha256.Sum256(preimage[:]) + + sendReq := &routerrpc.SendToRouteRequest{ + PaymentHash: hash[:], + Route: r, + } + + htlcAttempt := alice.RPC.SendToRouteV2(sendReq) + + // Since Carol doesn't understand blinded routes, we expect her to fail + // the payment because the onion payload is invalid (missing amount to + // forward). + require.NotNil(ht, htlcAttempt.Failure) + require.Equal(ht, uint32(2), htlcAttempt.Failure.FailureSourceIndex) + + // Next, we test an edge case where just an introduction node is + // included as a "single hop blinded route". + sendToIntroCLTVFinal := uint32(15) + sendToIntroTimelock := info.BlockHeight + bobCarolDelta + + sendToIntroCLTVFinal + + introNodeBlinded := &lnrpc.BlindedPaymentPath{ + BlindedPath: &lnrpc.BlindedPath{ + IntroductionNode: carol.PubKey[:], + BlindingPoint: blindingBytes, + BlindedHops: []*lnrpc.BlindedHop{ + { + // The first hop in the blinded route is + // expected to be the introduction node. + BlindedNode: carolBlindedBytes, + EncryptedData: encryptedDataCarol, + }, + }, + }, + // Fees should be zero for a single hop blinded path, and the + // total cltv expiry is just expected to cover the final cltv + // delta of the receiving node (ie, the introduction node). + BaseFeeMsat: 0, + TotalCltvDelta: sendToIntroCLTVFinal, + } + req = &lnrpc.QueryRoutesRequest{ + AmtMsat: paymentAmt, + BlindedPaymentPaths: []*lnrpc.BlindedPaymentPath{ + introNodeBlinded, + }, + } + + // Assert that we have one route, and two hops: Alice/Bob and Bob/Carol. + resp = alice.RPC.QueryRoutes(req) + require.Len(ht, resp.Routes, 1) + require.Len(ht, resp.Routes[0].Hops, 2) + require.Equal(ht, resp.Routes[0].TotalTimeLock, sendToIntroTimelock) + + ht.CloseChannel(alice, chanPointAliceBob) + ht.CloseChannel(bob, chanPointBobCarol) +} + +type blindedForwardTest struct { + ht *lntest.HarnessTest + carol *node.HarnessNode + dave *node.HarnessNode + channels []*lnrpc.ChannelPoint + + carolInterceptor routerrpc.Router_HtlcInterceptorClient + + preimage [32]byte + + // cancel will cancel the test's top level context. + cancel func() +} + +func newBlindedForwardTest(ht *lntest.HarnessTest) (context.Context, + *blindedForwardTest) { + + ctx, cancel := context.WithCancel(context.Background()) + + return ctx, &blindedForwardTest{ + ht: ht, + cancel: cancel, + preimage: [32]byte{1, 2, 3}, + } +} + +// setup spins up additional nodes needed for our test and creates a four hop +// network for testing blinded forwarding and returns a blinded route from +// Bob -> Carol -> Dave, with Bob acting as the introduction point and an +// interceptor on Carol's node to manage HTLCs (as Dave does not yet support +// receiving). +func (b *blindedForwardTest) setup( + ctx context.Context) *routing.BlindedPayment { + + b.carol = b.ht.NewNode("Carol", []string{ + "requireinterceptor", + }) + + var err error + b.carolInterceptor, err = b.carol.RPC.Router.HtlcInterceptor(ctx) + require.NoError(b.ht, err, "interceptor") + + b.dave = b.ht.NewNode("Dave", nil) + + b.channels = setupFourHopNetwork(b.ht, b.carol, b.dave) + + // Create a blinded route to Dave via Bob --- Carol --- Dave: + bobChan := b.ht.GetChannelByChanPoint(b.ht.Bob, b.channels[1]) + carolChan := b.ht.GetChannelByChanPoint(b.carol, b.channels[2]) + + edges := []*forwardingEdge{ + getForwardingEdge(b.ht, b.ht.Bob, bobChan.ChanId), + getForwardingEdge(b.ht, b.carol, carolChan.ChanId), + } + + davePk, err := btcec.ParsePubKey(b.dave.PubKey[:]) + require.NoError(b.ht, err, "dave pubkey") + + return b.createBlindedRoute(edges, davePk, 50) +} + +// cleanup tears down all channels created by the test and cancels the top +// level context used in the test. +func (b *blindedForwardTest) cleanup() { + b.ht.CloseChannel(b.ht.Alice, b.channels[0]) + b.ht.CloseChannel(b.ht.Bob, b.channels[1]) + b.ht.CloseChannel(b.carol, b.channels[2]) + + b.cancel() +} + +// createRouteToBlinded queries for a route from alice to the blinded path +// provided. +// +//nolint:gomnd +func (b *blindedForwardTest) createRouteToBlinded(paymentAmt int64, + route *routing.BlindedPayment) *lnrpc.Route { + + intro := route.BlindedPath.IntroductionPoint.SerializeCompressed() + blinding := route.BlindedPath.BlindingPoint.SerializeCompressed() + + blindedRoute := &lnrpc.BlindedPath{ + IntroductionNode: intro, + BlindingPoint: blinding, + BlindedHops: make( + []*lnrpc.BlindedHop, + len(route.BlindedPath.BlindedHops), + ), + } + + for i, hop := range route.BlindedPath.BlindedHops { + blindedRoute.BlindedHops[i] = &lnrpc.BlindedHop{ + BlindedNode: hop.BlindedNodePub.SerializeCompressed(), + EncryptedData: hop.CipherText, + } + } + blindedPath := &lnrpc.BlindedPaymentPath{ + BlindedPath: blindedRoute, + BaseFeeMsat: uint64( + route.BaseFee, + ), + ProportionalFeeRate: route.ProportionalFeeRate, + TotalCltvDelta: uint32( + route.CltvExpiryDelta, + ), + } + + req := &lnrpc.QueryRoutesRequest{ + AmtMsat: paymentAmt, + // Our fee limit doesn't really matter, we just want to + // be able to make the payment. + FeeLimit: &lnrpc.FeeLimit{ + Limit: &lnrpc.FeeLimit_Percent{ + Percent: 50, + }, + }, + BlindedPaymentPaths: []*lnrpc.BlindedPaymentPath{ + blindedPath, + }, + } + + resp := b.ht.Alice.RPC.QueryRoutes(req) + require.Greater(b.ht, len(resp.Routes), 0, "no routes") + require.Len(b.ht, resp.Routes[0].Hops, 3, "unexpected route length") + + return resp.Routes[0] +} + +// sendBlindedPayment dispatches a payment to the route provided. +func (b *blindedForwardTest) sendBlindedPayment(ctx context.Context, + route *lnrpc.Route) { + + hash := sha256.Sum256(b.preimage[:]) + sendReq := &routerrpc.SendToRouteRequest{ + PaymentHash: hash[:], + Route: route, + } + + // Dispatch in a goroutine because this call is blocking - we assume + // that we'll have assertions that this payment is sent by the caller. + go func() { + b.ht.Alice.RPC.SendToRouteV2(sendReq) + }() +} + +// interceptFinalHop launches a goroutine to intercept Carol's htlcs and +// returns a closure that can be used to resolve intercepted htlcs. +// +//nolint:lll +func (b *blindedForwardTest) interceptFinalHop() func(routerrpc.ResolveHoldForwardAction) { + hash := sha256.Sum256(b.preimage[:]) + htlcReceived := make(chan *routerrpc.ForwardHtlcInterceptRequest) + + // Launch a goroutine which will receive from the interceptor and pipe + // it into our request channel. + go func() { + forward, err := b.carolInterceptor.Recv() + if err != nil { + b.ht.Fatalf("intercept receive failed: %v", err) + } + + if !bytes.Equal(forward.PaymentHash, hash[:]) { + b.ht.Fatalf("unexpected payment hash: %v", hash) + } + + select { + case htlcReceived <- forward: + + case <-time.After(lntest.DefaultTimeout): + b.ht.Fatal("timeout waiting to send intercepted htlc") + } + }() + + // Create a closure that will wait for the intercept request and + // resolve the HTLC with the appropriate action. + resolve := func(action routerrpc.ResolveHoldForwardAction) { + select { + case forward := <-htlcReceived: + resp := &routerrpc.ForwardHtlcInterceptResponse{ + IncomingCircuitKey: forward.IncomingCircuitKey, + } + + switch action { + case routerrpc.ResolveHoldForwardAction_FAIL: + resp.Action = routerrpc.ResolveHoldForwardAction_FAIL + + case routerrpc.ResolveHoldForwardAction_SETTLE: + resp.Action = routerrpc.ResolveHoldForwardAction_SETTLE + resp.Preimage = b.preimage[:] + + case routerrpc.ResolveHoldForwardAction_RESUME: + resp.Action = routerrpc.ResolveHoldForwardAction_RESUME + } + + require.NoError(b.ht, b.carolInterceptor.Send(resp)) + + case <-time.After(lntest.DefaultTimeout): + b.ht.Fatal("timeout waiting for htlc intercept") + } + } + + return resolve +} + +// setupFourHopNetwork creates a network with the following topology and +// liquidity: +// Alice (100k)----- Bob (100k) ----- Carol (100k) ----- Dave +// +// The funding outpoint for AB / BC / CD are returned in-order. +func setupFourHopNetwork(ht *lntest.HarnessTest, + carol, dave *node.HarnessNode) []*lnrpc.ChannelPoint { + + const chanAmt = btcutil.Amount(100000) + var networkChans []*lnrpc.ChannelPoint + + // Open a channel with 100k satoshis between Alice and Bob with Alice + // being the sole funder of the channel. + chanPointAlice := ht.OpenChannel( + ht.Alice, ht.Bob, lntest.OpenChannelParams{ + Amt: chanAmt, + }, + ) + networkChans = append(networkChans, chanPointAlice) + + // Create a channel between bob and carol. + ht.EnsureConnected(ht.Bob, carol) + chanPointBob := ht.OpenChannel( + ht.Bob, carol, lntest.OpenChannelParams{ + Amt: chanAmt, + }, + ) + networkChans = append(networkChans, chanPointBob) + + // Fund carol and connect her and dave so that she can create a channel + // between them. + ht.FundCoins(btcutil.SatoshiPerBitcoin, carol) + ht.EnsureConnected(carol, dave) + + chanPointCarol := ht.OpenChannel( + carol, dave, lntest.OpenChannelParams{ + Amt: chanAmt, + }, + ) + networkChans = append(networkChans, chanPointCarol) + + // Wait for all nodes to have seen all channels. + nodes := []*node.HarnessNode{ht.Alice, ht.Bob, carol, dave} + for _, chanPoint := range networkChans { + for _, node := range nodes { + ht.AssertTopologyChannelOpen(node, chanPoint) + } + } + + return []*lnrpc.ChannelPoint{ + chanPointAlice, + chanPointBob, + chanPointCarol, + } +} + +// createBlindedRoute creates a blinded route to the recipient node provided. +// The set of hops is expected to start at the introduction node and end at +// the recipient. +func (b *blindedForwardTest) createBlindedRoute(hops []*forwardingEdge, + dest *btcec.PublicKey, finalCLTV uint16) *routing.BlindedPayment { + + // Create a path with space for each of our hops + the destination + // node. We include our passed final cltv delta here because blinded + // paths include the delta in the blinded portion (not the invoice). + blindedPayment := &routing.BlindedPayment{ + CltvExpiryDelta: finalCLTV, + } + + pathLength := len(hops) + 1 + blindedPath := make([]*sphinx.HopInfo, pathLength) + + // Run forwards through our hops to create blinded route data for each + // node with the next node's short channel id and our payment + // constraints. + for i := 0; i < len(hops); i++ { + node := hops[i] + scid := node.channelID + + // Set the relay information for this edge based on its policy. + delta := uint16(node.edge.TimeLockDelta) + relayInfo := &record.PaymentRelayInfo{ + BaseFee: uint32(node.edge.FeeBaseMsat), + FeeRate: uint32(node.edge.FeeRateMilliMsat), + CltvExpiryDelta: delta, + } + + // We set our constraints with our edge's actual htlc min, and + // an arbitrary maximum expiry (since it's just an anti-probing + // mechanism). + constraints := &record.PaymentConstraints{ + HtlcMinimumMsat: lnwire.MilliSatoshi(node.edge.MinHtlc), + MaxCltvExpiry: 100000, + } + + // Add CLTV delta of each hop to the blinded payment. + blindedPayment.CltvExpiryDelta += delta + + // Encode the route's blinded data and include it in the + // blinded hop. + payload := record.NewBlindedRouteData( + scid, nil, *relayInfo, constraints, nil, + ) + payloadBytes, err := record.EncodeBlindedRouteData(payload) + require.NoError(b.ht, err) + + blindedPath[i] = &sphinx.HopInfo{ + NodePub: node.pubkey, + PlainText: payloadBytes, + } + } + + // Next, we'll run backwards through our route to build up the aggregate + // fees for the blinded payment as a whole. This is done in a separate + // loop for the sake of readability. + // + // For blinded path aggregated fees, we start at the receiving node + // and add up base an proportional fees *including* the fees that we'll + // charge on accumulated fees. We use the int ceiling to round up so + // that the sender will always over-pay, ensuring that we don't round + // down along the route leaving one forwarding node short of what + // they're expecting. + var ( + hopCount = len(hops) - 1 + currentHopBaseFee = hops[hopCount].edge.FeeBaseMsat + currentHopPropFee = hops[hopCount].edge.FeeRateMilliMsat + feeParts int64 = 1e6 + ) + + // Note: the spec says to iterate backwards, but then uses n / n +1 to + // express the "next" hop in the route going backwards. This works for + // languages where we can iterate backwards and get an increasing + // index, but since we're counting backwards we use n-1 instead. + // + // Specification reference: + //nolint:lll + // https://github.com/lightning/bolts/blob/60de4a09727c20dea330f9ee8313034de6e50594/proposals/route-blinding.md?plain=1#L253-L254 + for i := hopCount; i > 0; i-- { + preceedingBase := hops[i-1].edge.FeeBaseMsat + preceedingProp := hops[i-1].edge.FeeBaseMsat + + // Separate numerator from ceiling division to break up large + // lines. + baseFeeNumerator := preceedingBase*feeParts + + currentHopBaseFee*(feeParts+preceedingProp) + currentHopBaseFee = (baseFeeNumerator + feeParts - 1) / feeParts + + propFeeNumerator := (currentHopPropFee+preceedingProp)* + feeParts + currentHopPropFee*preceedingProp + currentHopPropFee = (propFeeNumerator + feeParts - 1) / feeParts + } + + blindedPayment.BaseFee = uint32(currentHopBaseFee) + blindedPayment.ProportionalFeeRate = uint32(currentHopPropFee) + + // Add our destination node at the end of the path. We don't need to + // add any forwarding parameters because we're at the final hop. + payloadBytes, err := record.EncodeBlindedRouteData( + // TODO: we don't have support for the final hop fields, + // because only forwarding is supported. We add a next + // node ID here so that it _looks like_ a valid + // forwarding hop (though in reality it's the last + // hop). + record.NewBlindedRouteData( + lnwire.NewShortChanIDFromInt(100), nil, + record.PaymentRelayInfo{}, nil, nil, + ), + ) + require.NoError(b.ht, err, "final payload") + + blindedPath[pathLength-1] = &sphinx.HopInfo{ + NodePub: dest, + PlainText: payloadBytes, + } + + // Blind the path. + blindingKey, err := btcec.NewPrivateKey() + require.NoError(b.ht, err) + + blindedPayment.BlindedPath, err = sphinx.BuildBlindedPath( + blindingKey, blindedPath, + ) + require.NoError(b.ht, err, "build blinded path") + + return blindedPayment +} + +// forwardingEdge contains the channel id/source public key for a forwarding +// edge and the policy associated with the channel in that direction. +type forwardingEdge struct { + pubkey *btcec.PublicKey + channelID lnwire.ShortChannelID + edge *lnrpc.RoutingPolicy +} + +func getForwardingEdge(ht *lntest.HarnessTest, + node *node.HarnessNode, chanID uint64) *forwardingEdge { + + chanInfo := node.RPC.GetChanInfo(&lnrpc.ChanInfoRequest{ + ChanId: chanID, + }) + + pubkey, err := btcec.ParsePubKey(node.PubKey[:]) + require.NoError(ht, err, "%v pubkey", node.Cfg.Name) + + fwdEdge := &forwardingEdge{ + pubkey: pubkey, + channelID: lnwire.NewShortChanIDFromInt(chanID), + } + + if chanInfo.Node1Pub == node.PubKeyStr { + fwdEdge.edge = chanInfo.Node1Policy + } else { + require.Equal(ht, node.PubKeyStr, chanInfo.Node2Pub, + "policy edge sanity check") + + fwdEdge.edge = chanInfo.Node2Policy + } + + return fwdEdge +} + +// testForwardBlindedRoute tests lnd's ability to forward payments in a blinded +// route. +func testForwardBlindedRoute(ht *lntest.HarnessTest) { + ctx, testCase := newBlindedForwardTest(ht) + defer testCase.cleanup() + + route := testCase.setup(ctx) + blindedRoute := testCase.createRouteToBlinded(10_000_000, route) + + // Receiving via blinded routes is not yet supported, so Dave won't be + // able to process the payment. + // + // We have an interceptor at our disposal that will catch htlcs as they + // are forwarded (ie, it won't intercept a HTLC that dave is receiving, + // since no forwarding occurs). We initiate this interceptor with + // Carol, so that we can catch it and settle on the outgoing link to + // Dave. Once we hit the outgoing link, we know that we successfully + // parsed the htlc, so this is an acceptable compromise. + // Assert that our interceptor has exited without an error. + resolveHTLC := testCase.interceptFinalHop() + + // Once our interceptor is set up, we can send the blinded payment. + testCase.sendBlindedPayment(ctx, blindedRoute) + + // Wait for the HTLC to be active on Alice's channel. + hash := sha256.Sum256(testCase.preimage[:]) + ht.AssertOutgoingHTLCActive(ht.Alice, testCase.channels[0], hash[:]) + ht.AssertOutgoingHTLCActive(ht.Bob, testCase.channels[1], hash[:]) + + // Intercept and settle the HTLC. + resolveHTLC(routerrpc.ResolveHoldForwardAction_SETTLE) + + // Wait for the HTLC to reflect as settled for Alice. + preimage, err := lntypes.MakePreimage(testCase.preimage[:]) + require.NoError(ht, err) + ht.AssertPaymentStatus(ht.Alice, preimage, lnrpc.Payment_SUCCEEDED) + + // Assert that the HTLC has settled before test cleanup runs so that + // we can cooperatively close all channels. + ht.AssertHLTCNotActive(ht.Bob, testCase.channels[1], hash[:]) + ht.AssertHLTCNotActive(ht.Alice, testCase.channels[0], hash[:]) +} diff --git a/lncfg/protocol.go b/lncfg/protocol.go index f8ac08e86bd..59027a09b75 100644 --- a/lncfg/protocol.go +++ b/lncfg/protocol.go @@ -54,6 +54,17 @@ type ProtocolOptions struct { // also mean that we won't respond with timestamps if requested by our // peers. NoTimestampQueryOption bool `long:"no-timestamp-query-option" description:"do not query syncing peers for announcement timestamps and do not respond with timestamps if requested"` + + // NoRouteBlindingOption disables forwarding of payments in blinded routes. + NoRouteBlindingOption bool `long:"no-route-blinding" description:"do not forward payments that are a part of a blinded route"` +} + +// DefaultProtocol returns a protocol config with route blinding turned off, +// temporarily in place until full handling of blinded route errors is merged. +func DefaultProtocol() *ProtocolOptions { + return &ProtocolOptions{ + NoRouteBlindingOption: true, + } } // Wumbo returns true if lnd should permit the creation and acceptance of wumbo @@ -97,3 +108,8 @@ func (l *ProtocolOptions) NoAnySegwit() bool { func (l *ProtocolOptions) NoTimestampsQuery() bool { return l.NoTimestampQueryOption } + +// NoRouteBlinding returns true if forwarding of blinded payments is disabled. +func (l *ProtocolOptions) NoRouteBlinding() bool { + return l.NoRouteBlindingOption +} diff --git a/lncfg/protocol_integration.go b/lncfg/protocol_integration.go index ff74ba9e908..f44aa124693 100644 --- a/lncfg/protocol_integration.go +++ b/lncfg/protocol_integration.go @@ -57,6 +57,16 @@ type ProtocolOptions struct { // also mean that we won't respond with timestamps if requested by our // peers. NoTimestampQueryOption bool `long:"no-timestamp-query-option" description:"do not query syncing peers for announcement timestamps and do not respond with timestamps if requested"` + + // NoRouteBlindingOption disables forwarding of payments in blinded routes. + NoRouteBlindingOption bool `long:"no-route-blinding" description:"do not forward payments that are a part of a blinded route"` +} + +// DefaultProtocol returns a protocol config with route blinding turned on, +// so that itests can run against route blinding features even while we've +// got it turned off for the daemon (pending completion of error handling). +func DefaultProtocol() *ProtocolOptions { + return &ProtocolOptions{} } // Wumbo returns true if lnd should permit the creation and acceptance of wumbo @@ -92,3 +102,8 @@ func (l *ProtocolOptions) ZeroConf() bool { func (l *ProtocolOptions) NoAnySegwit() bool { return l.NoOptionAnySegwit } + +// NoRouteBlinding returns true if forwarding of blinded payments is disabled. +func (l *ProtocolOptions) NoRouteBlinding() bool { + return l.NoRouteBlindingOption +} diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 1b6e71ffd6f..1af0e4b5ffc 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -31,7 +31,6 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" - "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -377,7 +376,7 @@ type PaymentDescriptor struct { // This value is set for nodes that are relaying payments inside of a // blinded route (ie, not the introduction node) from update_add_htlc's // TLVs. - BlindingPoint *btcec.PublicKey + BlindingPoint lnwire.BlindingPointRecord } // PayDescsFromRemoteLogUpdates converts a slice of LogUpdates received from the @@ -418,7 +417,7 @@ func PayDescsFromRemoteLogUpdates(chanID lnwire.ShortChannelID, height uint64, Height: height, Index: uint16(i), }, - BlindingPoint: wireMsg.BlingingPointOrNil(), + BlindingPoint: pd.BlindingPoint, } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) @@ -742,16 +741,9 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { HtlcIndex: htlc.HtlcIndex, LogIndex: htlc.LogIndex, Incoming: false, + BlindingPoint: htlc.BlindingPoint, } copy(h.OnionBlob[:], htlc.OnionBlob) - if htlc.BlindingPoint != nil { - h.BlindingPoint = tlv.SomeRecordT( - //nolint:lll - tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( - htlc.BlindingPoint, - ), - ) - } if ourCommit && htlc.sig != nil { h.Signature = htlc.sig.Serialize() @@ -774,16 +766,9 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { HtlcIndex: htlc.HtlcIndex, LogIndex: htlc.LogIndex, Incoming: true, + BlindingPoint: htlc.BlindingPoint, } copy(h.OnionBlob[:], htlc.OnionBlob) - if htlc.BlindingPoint != nil { - h.BlindingPoint = tlv.SomeRecordT( - //nolint:lll - tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( - htlc.BlindingPoint, - ), - ) - } if ourCommit && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -866,7 +851,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // With the scripts reconstructed (depending on if this is our commit // vs theirs or a pending commit for the remote party), we can now // re-create the original payment descriptor. - pd = PaymentDescriptor{ + return PaymentDescriptor{ RHash: htlc.RHash, Timeout: htlc.RefundTimeout, Amount: htlc.Amt, @@ -880,15 +865,8 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, ourWitnessScript: ourWitnessScript, theirPkScript: theirP2WSH, theirWitnessScript: theirWitnessScript, - } - - htlc.BlindingPoint.WhenSome(func(b tlv.RecordT[ - lnwire.BlindingPointTlvType, *btcec.PublicKey]) { - - pd.BlindingPoint = b.Val - }) - - return pd, nil + BlindingPoint: htlc.BlindingPoint, + }, nil } // extractPayDescs will convert all HTLC's present within a disk commit state @@ -1577,7 +1555,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, HtlcIndex: wireMsg.ID, LogIndex: logUpdate.LogIndex, addCommitHeightRemote: commitHeight, - BlindingPoint: wireMsg.BlingingPointOrNil(), + BlindingPoint: wireMsg.BlindingPoint, } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) @@ -1775,7 +1753,7 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd HtlcIndex: wireMsg.ID, LogIndex: logUpdate.LogIndex, addCommitHeightLocal: commitHeight, - BlindingPoint: wireMsg.BlingingPointOrNil(), + BlindingPoint: wireMsg.BlindingPoint, } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob, wireMsg.OnionBlob[:]) @@ -3631,21 +3609,14 @@ func (lc *LightningChannel) createCommitDiff( switch pd.EntryType { case Add: htlc := &lnwire.UpdateAddHTLC{ - ChanID: chanID, - ID: pd.HtlcIndex, - Amount: pd.Amount, - Expiry: pd.Timeout, - PaymentHash: pd.RHash, + ChanID: chanID, + ID: pd.HtlcIndex, + Amount: pd.Amount, + Expiry: pd.Timeout, + PaymentHash: pd.RHash, + BlindingPoint: pd.BlindingPoint, } copy(htlc.OnionBlob[:], pd.OnionBlob) - if pd.BlindingPoint != nil { - htlc.BlindingPoint = tlv.SomeRecordT( - //nolint:lll - tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( - pd.BlindingPoint, - ), - ) - } logUpdate.UpdateMsg = htlc // Gather any references for circuits opened by this Add @@ -3775,21 +3746,13 @@ func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate { // four messages that it corresponds to. switch pd.EntryType { case Add: - var b lnwire.BlindingPointRecord - if pd.BlindingPoint != nil { - tlv.SomeRecordT( - //nolint:lll - tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](pd.BlindingPoint), - ) - } - htlc := &lnwire.UpdateAddHTLC{ ChanID: chanID, ID: pd.HtlcIndex, Amount: pd.Amount, Expiry: pd.Timeout, PaymentHash: pd.RHash, - BlindingPoint: b, + BlindingPoint: pd.BlindingPoint, } copy(htlc.OnionBlob[:], pd.OnionBlob) logUpdate.UpdateMsg = htlc @@ -5784,19 +5747,12 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( switch pd.EntryType { case Add: htlc := &lnwire.UpdateAddHTLC{ - ChanID: chanID, - ID: pd.HtlcIndex, - Amount: pd.Amount, - Expiry: pd.Timeout, - PaymentHash: pd.RHash, - } - if pd.BlindingPoint != nil { - htlc.BlindingPoint = tlv.SomeRecordT( - //nolint:lll - tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( - pd.BlindingPoint, - ), - ) + ChanID: chanID, + ID: pd.HtlcIndex, + Amount: pd.Amount, + Expiry: pd.Timeout, + PaymentHash: pd.RHash, + BlindingPoint: pd.BlindingPoint, } copy(htlc.OnionBlob[:], pd.OnionBlob) logUpdate.UpdateMsg = htlc @@ -6135,7 +6091,7 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, HtlcIndex: lc.localUpdateLog.htlcCounter, OnionBlob: htlc.OnionBlob[:], OpenCircuitKey: openKey, - BlindingPoint: htlc.BlingingPointOrNil(), + BlindingPoint: htlc.BlindingPoint, } } @@ -6193,7 +6149,7 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err LogIndex: lc.remoteUpdateLog.logIndex, HtlcIndex: lc.remoteUpdateLog.htlcCounter, OnionBlob: htlc.OnionBlob[:], - BlindingPoint: htlc.BlingingPointOrNil(), + BlindingPoint: htlc.BlindingPoint, } localACKedIndex := lc.remoteCommitChain.tail().ourMessageIndex diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index d224b45983e..7ef5c118aa8 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -11045,7 +11045,8 @@ func TestBlindingPointPersistence(t *testing.T) { // Assert that the blinding point is restored from disk. remoteCommit := aliceChannel.remoteCommitChain.tip() require.Len(t, remoteCommit.outgoingHTLCs, 1) - require.Equal(t, blinding, remoteCommit.outgoingHTLCs[0].BlindingPoint) + require.Equal(t, blinding, + remoteCommit.outgoingHTLCs[0].BlindingPoint.UnwrapOrFailV(t)) // Next, update bob's commitment and assert that we can still retrieve // his incoming blinding point after restart. @@ -11061,5 +11062,6 @@ func TestBlindingPointPersistence(t *testing.T) { // Assert that Bob is able to recover the blinding point from disk. bobCommit := bobChannel.localCommitChain.tip() require.Len(t, bobCommit.incomingHTLCs, 1) - require.Equal(t, blinding, bobCommit.incomingHTLCs[0].BlindingPoint) + require.Equal(t, blinding, + bobCommit.incomingHTLCs[0].BlindingPoint.UnwrapOrFailV(t)) } diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 951dc7f54cf..8a40710e82a 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -78,19 +78,6 @@ type UpdateAddHTLC struct { ExtraData ExtraOpaqueData } -// BlingingPointOrNil returns the blinding point associated with the update, or -// nil. -func (c *UpdateAddHTLC) BlingingPointOrNil() *btcec.PublicKey { - var blindingPoint *btcec.PublicKey - c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType, - *btcec.PublicKey]) { - - blindingPoint = b.Val - }) - - return blindingPoint -} - // NewUpdateAddHTLC returns a new empty UpdateAddHTLC message. func NewUpdateAddHTLC() *UpdateAddHTLC { return &UpdateAddHTLC{} diff --git a/peer/brontide.go b/peer/brontide.go index 541c0f358a9..187fbafe179 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -365,6 +365,11 @@ type Config struct { // this across multiple Peer struct instances. PongBuf []byte + // Adds the option to disable forwarding payments in blinded routes + // by failing back any blinding-related payloads as if they were + // invalid. + DisallowRouteBlinding bool + // Quit is the server's quit channel. If this is closed, we halt operation. Quit chan struct{} } @@ -1155,6 +1160,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, HtlcNotifier: p.cfg.HtlcNotifier, GetAliases: p.cfg.GetAliases, PreviouslySentShutdown: shutdownMsg, + DisallowRouteBlinding: p.cfg.DisallowRouteBlinding, } // Before adding our new link, purge the switch of any pending or live diff --git a/sample-lnd.conf b/sample-lnd.conf index dd538b07a29..08a79bde5b9 100644 --- a/sample-lnd.conf +++ b/sample-lnd.conf @@ -1282,6 +1282,9 @@ ; Set to enable support for the experimental taproot channel type. ; protocol.simple-taproot-chans=false +; Set to disable blinded route forwarding. +; protocol.no-route-blinding=false + [db] ; The selected database backend. The current default backend is "bolt". lnd diff --git a/server.go b/server.go index ea51242e4ae..2c8b75af10c 100644 --- a/server.go +++ b/server.go @@ -3872,6 +3872,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, GetAliases: s.aliasMgr.GetAliases, RequestAlias: s.aliasMgr.RequestAlias, AddLocalAlias: s.aliasMgr.AddLocalAlias, + DisallowRouteBlinding: s.cfg.ProtocolOptions.NoRouteBlinding(), Quit: s.quit, }