diff --git a/go.mod b/go.mod index d6ea9268c09..83df59f7701 100644 --- a/go.mod +++ b/go.mod @@ -224,4 +224,7 @@ replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-d // well). go 1.25.5 +// Temporary replace until dependent PR is merged in lightning-onion. +replace github.com/lightningnetwork/lightning-onion => github.com/joostjager/lightning-onion v0.0.0-20260312135706-2dd58e7b9794 + retract v0.0.2 diff --git a/go.sum b/go.sum index 0ca378d9d38..09ee8d3a5dc 100644 --- a/go.sum +++ b/go.sum @@ -308,6 +308,8 @@ github.com/jessevdk/go-flags v1.6.1 h1:Cvu5U8UGrLay1rZfv/zP7iLpSHGUZ/Ou68T0iX1bB github.com/jessevdk/go-flags v1.6.1/go.mod h1:Mk8T1hIAWpOiJiHa9rJASDK2UGWji0EuPGBnNLMooyc= github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= +github.com/joostjager/lightning-onion v0.0.0-20260312135706-2dd58e7b9794 h1:xKnZDFhNa3yoJWF2XrSKzSx76qXd79FF46vA8Jym16s= +github.com/joostjager/lightning-onion v0.0.0-20260312135706-2dd58e7b9794/go.mod h1:nP85zMHG7c0si/eHBbSQpuDCtnIXfSvFrK3tW6YWzmU= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= github.com/jrick/logrotate v1.1.2 h1:6ePk462NCX7TfKtNp5JJ7MbA2YIslkpfgP03TlTYMN0= @@ -370,8 +372,6 @@ github.com/lightninglabs/neutrino/cache v1.1.3 h1:rgnabC41W+XaPuBTQrdeFjFCCAVKh1 github.com/lightninglabs/neutrino/cache v1.1.3/go.mod h1:qxkJb+pUxR5p84jl5uIGFCR4dGdFkhNUwMSxw3EUWls= github.com/lightninglabs/protobuf-go-hex-display v1.33.0-hex-display h1:Y2WiPkBS/00EiEg0qp0FhehxnQfk3vv8U6Xt3nN+rTY= github.com/lightninglabs/protobuf-go-hex-display v1.33.0-hex-display/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= -github.com/lightningnetwork/lightning-onion v1.3.0 h1:FqILgHjD6euc/Muo1VOzZ4+XDPuFnw6EYROBq0rR/5c= -github.com/lightningnetwork/lightning-onion v1.3.0/go.mod h1:nP85zMHG7c0si/eHBbSQpuDCtnIXfSvFrK3tW6YWzmU= github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf0d0Uy4qBjI= github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= diff --git a/htlcswitch/attributable_failure_test.go b/htlcswitch/attributable_failure_test.go new file mode 100644 index 00000000000..7b443fcbb03 --- /dev/null +++ b/htlcswitch/attributable_failure_test.go @@ -0,0 +1,241 @@ +package htlcswitch + +import ( + "bytes" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/htlcswitch/hop" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +// deriveSharedSecrets derives the shared secrets for each hop along the payment +// path using the session key, mirroring the logic of sphinx's internal +// generateSharedSecrets. +func deriveSharedSecrets(t *testing.T, paymentPath []*btcec.PublicKey, + sessionKey *btcec.PrivateKey) []sphinx.Hash256 { + + t.Helper() + + numHops := len(paymentPath) + secrets := make([]sphinx.Hash256, numHops) + + ephemECDH := &sphinx.PrivKeyECDH{PrivKey: sessionKey} + + // First hop. + ss, err := ephemECDH.ECDH(paymentPath[0]) + require.NoError(t, err) + secrets[0] = ss + + // Subsequent hops: derive the next ephemeral private key using the + // blinding factor. + for i := 1; i < numHops; i++ { + nextPriv, err := sphinx.NextEphemeralPriv( + ephemECDH, paymentPath[i-1], + ) + require.NoError(t, err) + + ephemECDH = &sphinx.PrivKeyECDH{PrivKey: nextPriv} + + ss, err = ephemECDH.ECDH(paymentPath[i]) + require.NoError(t, err) + secrets[i] = ss + } + + return secrets +} + +// TestAttributableFailureEndToEnd exercises the full encrypt → intermediate +// encrypt → decrypt flow with attribution data and validates that HoldTimes +// are correctly populated. +func TestAttributableFailureEndToEnd(t *testing.T) { + t.Parallel() + + const numHops = 4 + + // Generate random node keys for the payment path. + paymentPath := make([]*btcec.PublicKey, numHops) + for i := 0; i < numHops; i++ { + privKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + paymentPath[i] = privKey.PubKey() + } + + // Use a deterministic session key. + sessionKey, _ := btcec.PrivKeyFromBytes( + bytes.Repeat([]byte{0x42}, 32), + ) + + // Derive per-hop shared secrets. + sharedSecrets := deriveSharedSecrets(t, paymentPath, sessionKey) + + // The failing node is hop index 2 (third node, 0-indexed). + failingHopIdx := 2 + + // Create a failure message at the failing hop. + failureMsg := lnwire.NewFailIncorrectDetails(1000, 100) + + // Create the error encrypter at the failing hop, with a creation time + // slightly in the past to get a non-zero hold time. + failEncrypter := hop.NewSphinxErrorEncrypter( + paymentPath[failingHopIdx], + sharedSecrets[failingHopIdx], + ) + failEncrypter.CreatedAt = time.Now().Add(-200 * time.Millisecond) + + // Encrypt at the origin of the failure. + reason, attrData, err := failEncrypter.EncryptFirstHop(failureMsg) + require.NoError(t, err) + require.NotEmpty(t, reason) + require.NotEmpty(t, attrData, "attribution data should be populated") + + // Wrap the attribution data in ExtraOpaqueData for transmission. + extraData, err := lnwire.AttrDataToExtraData(attrData) + require.NoError(t, err) + + // Intermediate encrypt at each hop back to the sender. + for i := failingHopIdx - 1; i >= 0; i-- { + intermediateEnc := hop.NewSphinxErrorEncrypter( + paymentPath[i], + sharedSecrets[i], + ) + // Set a slightly older creation time to simulate hold time. + intermediateEnc.CreatedAt = time.Now().Add( + -100 * time.Millisecond, + ) + + // Extract attr data from the extra data (as it would come from + // the wire message). + attrData, err = lnwire.ExtraDataToAttrData(extraData) + require.NoError(t, err) + + reason, attrData, err = intermediateEnc.IntermediateEncrypt( + reason, attrData, + ) + require.NoError(t, err) + + extraData, err = lnwire.AttrDataToExtraData(attrData) + require.NoError(t, err) + } + + // Now decrypt at the sender using the SphinxErrorDecrypter. + circuit := &sphinx.Circuit{ + SessionKey: sessionKey, + PaymentPath: paymentPath, + } + decrypter := NewSphinxErrorDecrypter(circuit) + + attrData, err = lnwire.ExtraDataToAttrData(extraData) + require.NoError(t, err) + + fwdErr, err := decrypter.DecryptError(reason, attrData) + require.NoError(t, err) + + // Verify the failure source is identified correctly. + // SenderIdx is 1-indexed (0 = self), so failing hop index 2 means + // SenderIdx = 3. + require.Equal(t, failingHopIdx+1, fwdErr.FailureSourceIdx, + "failure source index mismatch") + + // Verify we got the right failure message back. + msg := fwdErr.WireMessage() + incorrectDetails, ok := msg.(*lnwire.FailIncorrectDetails) + require.True(t, ok, "expected FailIncorrectDetails, got %T", + fwdErr.WireMessage()) + require.EqualValues(t, 1000, incorrectDetails.Amount()) + require.EqualValues(t, 100, incorrectDetails.Height()) + + // Verify that HoldTimes are populated. We should have hold times for + // hops 1 through failingHopIdx (the failing node plus intermediates). + require.NotEmpty(t, fwdErr.HoldTimes, + "expected non-empty hold times") +} + +// TestAttributableFailureWithoutAttrData tests that decryption works without +// attribution data (backward compatibility with non-attributable errors). +func TestAttributableFailureWithoutAttrData(t *testing.T) { + t.Parallel() + + const numHops = 3 + + paymentPath := make([]*btcec.PublicKey, numHops) + for i := 0; i < numHops; i++ { + privKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + paymentPath[i] = privKey.PubKey() + } + + sessionKey, _ := btcec.PrivKeyFromBytes( + bytes.Repeat([]byte{0x33}, 32), + ) + + sharedSecrets := deriveSharedSecrets(t, paymentPath, sessionKey) + + // Failing hop is the last node. + failingHopIdx := numHops - 1 + + failureMsg := lnwire.NewFailIncorrectDetails(500, 50) + + failEncrypter := hop.NewSphinxErrorEncrypter( + paymentPath[failingHopIdx], + sharedSecrets[failingHopIdx], + ) + + reason, _, err := failEncrypter.EncryptFirstHop(failureMsg) + require.NoError(t, err) + + // Intermediate hops encrypt WITHOUT using attribution data (passing + // nil), simulating nodes that don't support attributable failures. + for i := failingHopIdx - 1; i >= 0; i-- { + intermediateEnc := hop.NewSphinxErrorEncrypter( + paymentPath[i], + sharedSecrets[i], + ) + + reason, _, err = intermediateEnc.IntermediateEncrypt( + reason, nil, + ) + require.NoError(t, err) + } + + // Decrypt at the sender without attribution data. + circuit := &sphinx.Circuit{ + SessionKey: sessionKey, + PaymentPath: paymentPath, + } + decrypter := NewSphinxErrorDecrypter(circuit) + + fwdErr, err := decrypter.DecryptError(reason, nil) + require.NoError(t, err) + + // The failure source should still be correctly identified via the + // legacy HMAC-based mechanism. + require.Equal(t, failingHopIdx+1, fwdErr.FailureSourceIdx) + + msg := fwdErr.WireMessage() + incorrectDetails, ok := msg.(*lnwire.FailIncorrectDetails) + require.True(t, ok) + require.EqualValues(t, 500, incorrectDetails.Amount()) +} + +// TestNewForwardingErrorHoldTimes verifies that NewForwardingError correctly +// stores and exposes HoldTimes. +func TestNewForwardingErrorHoldTimes(t *testing.T) { + t.Parallel() + + holdTimes := []uint32{10, 20, 30, 40} + failure := lnwire.NewFailIncorrectDetails(100, 10) + + fwdErr := NewForwardingError(failure, 3, holdTimes) + + require.Equal(t, 3, fwdErr.FailureSourceIdx) + require.Equal(t, holdTimes, fwdErr.HoldTimes) + require.NotNil(t, fwdErr.WireMessage()) + + // With nil hold times. + fwdErr2 := NewForwardingError(failure, 1, nil) + require.Nil(t, fwdErr2.HoldTimes) +} diff --git a/htlcswitch/circuit.go b/htlcswitch/circuit.go index eab1cdb2002..e18e1f0b21c 100644 --- a/htlcswitch/circuit.go +++ b/htlcswitch/circuit.go @@ -199,17 +199,18 @@ func (c *PaymentCircuit) Decode(r io.Reader) error { case hop.EncrypterTypeSphinx: // Sphinx encrypter was used as this is a forwarded HTLC. - c.ErrorEncrypter = hop.NewSphinxErrorEncrypter() + c.ErrorEncrypter = hop.NewSphinxErrorEncrypterUninitialized() case hop.EncrypterTypeMock: // Test encrypter. c.ErrorEncrypter = NewMockObfuscator() case hop.EncrypterTypeIntroduction: - c.ErrorEncrypter = hop.NewIntroductionErrorEncrypter() + c.ErrorEncrypter = + hop.NewIntroductionErrorEncrypterUninitialized() case hop.EncrypterTypeRelaying: - c.ErrorEncrypter = hop.NewRelayingErrorEncrypter() + c.ErrorEncrypter = hop.NewRelayingErrorEncrypterUninitialized() default: return UnknownEncrypterType(encrypterType) diff --git a/htlcswitch/circuit_map.go b/htlcswitch/circuit_map.go index 15d4b5ffca4..3be27ee9ceb 100644 --- a/htlcswitch/circuit_map.go +++ b/htlcswitch/circuit_map.go @@ -210,9 +210,9 @@ type CircuitMapConfig struct { FetchClosedChannels func( pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error) - // ExtractErrorEncrypter derives the shared secret used to encrypt - // errors from the obfuscator's ephemeral public key. - ExtractErrorEncrypter hop.ErrorEncrypterExtracter + // ExtractSharedSecret derives the shared secret used to encrypt errors + // from the obfuscator's ephemeral public key. + ExtractSharedSecret hop.SharedSecretGenerator // CheckResolutionMsg checks whether a given resolution message exists // for the passed CircuitKey. @@ -632,9 +632,7 @@ func (cm *circuitMap) decodeCircuit(v []byte) (*PaymentCircuit, error) { // Otherwise, we need to reextract the encrypter, so that the shared // secret is rederived from what was decoded. - err := circuit.ErrorEncrypter.Reextract( - cm.cfg.ExtractErrorEncrypter, - ) + err := circuit.ErrorEncrypter.Reextract(cm.cfg.ExtractSharedSecret) if err != nil { return nil, err } diff --git a/htlcswitch/circuit_test.go b/htlcswitch/circuit_test.go index ddad11aca91..975a1c89899 100644 --- a/htlcswitch/circuit_test.go +++ b/htlcswitch/circuit_test.go @@ -65,16 +65,17 @@ func initTestExtracter() { onionProcessor := newOnionProcessor(nil) defer onionProcessor.Stop() - obfuscator, _ := onionProcessor.ExtractErrorEncrypter( + sharedSecret, failCode := onionProcessor.ExtractSharedSecret( testEphemeralKey, ) - sphinxExtracter, ok := obfuscator.(*hop.SphinxErrorEncrypter) - if !ok { - panic("did not extract sphinx error encrypter") + if failCode != lnwire.CodeNone { + panic("did not extract shared secret") } - testExtracter = sphinxExtracter + testExtracter = hop.NewSphinxErrorEncrypter( + testEphemeralKey, sharedSecret, + ) // We also set this error extracter on startup, otherwise it will be nil // at compile-time. @@ -106,10 +107,10 @@ func newCircuitMap(t *testing.T, resMsg bool) (*htlcswitch.CircuitMapConfig, db := makeCircuitDB(t, "") circuitMapCfg := &htlcswitch.CircuitMapConfig{ - DB: db, - FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, - FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, - ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter, + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, + ExtractSharedSecret: onionProcessor.ExtractSharedSecret, } if resMsg { @@ -216,7 +217,7 @@ func TestHalfCircuitSerialization(t *testing.T) { // encrypters, this will be a NOP. if circuit2.ErrorEncrypter != nil { err := circuit2.ErrorEncrypter.Reextract( - onionProcessor.ExtractErrorEncrypter, + onionProcessor.ExtractSharedSecret, ) if err != nil { t.Fatalf("unable to reextract sphinx error "+ @@ -643,11 +644,11 @@ func restartCircuitMap(t *testing.T, cfg *htlcswitch.CircuitMapConfig) ( // Reinitialize circuit map with same db path. db := makeCircuitDB(t, dbPath) cfg2 := &htlcswitch.CircuitMapConfig{ - DB: db, - FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, - FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, - ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, - CheckResolutionMsg: cfg.CheckResolutionMsg, + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, + ExtractSharedSecret: cfg.ExtractSharedSecret, + CheckResolutionMsg: cfg.CheckResolutionMsg, } cm2, err := htlcswitch.NewCircuitMap(cfg2) require.NoError(t, err, "unable to recreate persistent circuit map") diff --git a/htlcswitch/failure.go b/htlcswitch/failure.go index 373263381fd..fa75c724719 100644 --- a/htlcswitch/failure.go +++ b/htlcswitch/failure.go @@ -3,6 +3,7 @@ package htlcswitch import ( "bytes" "fmt" + "strings" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -92,6 +93,12 @@ type ForwardingError struct { // be nil in the case where we fail to decode failure message sent by // a peer. msg lnwire.FailureMessage + + // HoldTimes is an array of hold times (in 100ms units) as reported from + // the nodes of the route. Multiply by 100 to get milliseconds. The + // first element corresponds to the first node after the sender node, + // with greater indices indicating nodes further down the route. + HoldTimes []uint32 } // WireMessage extracts a valid wire failure message from an internal @@ -116,11 +123,12 @@ func (f *ForwardingError) Error() string { // NewForwardingError creates a new payment error which wraps a wire error // with additional metadata. func NewForwardingError(failure lnwire.FailureMessage, - index int) *ForwardingError { + index int, holdTimes []uint32) *ForwardingError { return &ForwardingError{ FailureSourceIdx: index, msg: failure, + HoldTimes: holdTimes, } } @@ -140,7 +148,7 @@ type ErrorDecrypter interface { // hop, to the source of the error. A fully populated // lnwire.FailureMessage is returned along with the source of the // error. - DecryptError(lnwire.OpaqueReason) (*ForwardingError, error) + DecryptError(lnwire.OpaqueReason, []byte) (*ForwardingError, error) } // UnknownEncrypterType is an error message used to signal that an unexpected @@ -152,21 +160,19 @@ func (e UnknownEncrypterType) Error() string { return fmt.Sprintf("unknown error encrypter type: %d", e) } -// OnionErrorDecrypter is the interface that provides onion level error -// decryption. -type OnionErrorDecrypter interface { - // DecryptError attempts to decrypt the passed encrypted error response. - // The onion failure is encrypted in backward manner, starting from the - // node where error have occurred. As a result, in order to decrypt the - // error we need get all shared secret and apply decryption in the - // reverse order. - DecryptError(encryptedData []byte) (*sphinx.DecryptedError, error) -} - // SphinxErrorDecrypter wraps the sphinx data SphinxErrorDecrypter and maps the // returned errors to concrete lnwire.FailureMessage instances. type SphinxErrorDecrypter struct { - OnionErrorDecrypter + decrypter *sphinx.OnionErrorDecrypter +} + +// NewSphinxErrorDecrypter instantiates a new error decrypter. +func NewSphinxErrorDecrypter(circuit *sphinx.Circuit) *SphinxErrorDecrypter { + return &SphinxErrorDecrypter{ + decrypter: sphinx.NewOnionErrorDecrypter( + circuit, hop.AttrErrorStruct, + ), + } } // DecryptError peels off each layer of onion encryption from the first hop, to @@ -174,23 +180,40 @@ type SphinxErrorDecrypter struct { // along with the source of the error. // // NOTE: Part of the ErrorDecrypter interface. -func (s *SphinxErrorDecrypter) DecryptError(reason lnwire.OpaqueReason) ( - *ForwardingError, error) { - - failure, err := s.OnionErrorDecrypter.DecryptError(reason) +func (s *SphinxErrorDecrypter) DecryptError(reason lnwire.OpaqueReason, + attrData []byte) (*ForwardingError, error) { + + // We do not set the strict attribution flag, as we want to account for + // the grace period during which nodes are still upgrading to support + // this feature. If set prematurely it can lead to early blame of our + // direct peers that may not support this feature yet, blacklisting our + // channels and failing our payments. + attrErr, err := s.decrypter.DecryptError(reason, attrData, false) if err != nil { return nil, err } + var holdTimeStrs []string + for _, ht := range attrErr.HoldTimes { + holdTimeStrs = append( + holdTimeStrs, fmt.Sprintf("%vms", ht*100), + ) + } + + log.Debugf("Extracted hold times from onion error: %v", + strings.Join(holdTimeStrs, "/")) + // Decode the failure. If an error occurs, we leave the failure message // field nil. - r := bytes.NewReader(failure.Message) + r := bytes.NewReader(attrErr.Message) failureMsg, err := lnwire.DecodeFailure(r, 0) if err != nil { - return NewUnknownForwardingError(failure.SenderIdx), nil + return NewUnknownForwardingError(attrErr.SenderIdx), nil } - return NewForwardingError(failureMsg, failure.SenderIdx), nil + return NewForwardingError( + failureMsg, attrErr.SenderIdx, attrErr.HoldTimes, + ), nil } // A compile time check to ensure ErrorDecrypter implements the Deobfuscator diff --git a/htlcswitch/failure_test.go b/htlcswitch/failure_test.go index 48ebc668210..d887be5cfb8 100644 --- a/htlcswitch/failure_test.go +++ b/htlcswitch/failure_test.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" @@ -52,11 +53,13 @@ func TestLongFailureMessage(t *testing.T) { } errorDecryptor := &SphinxErrorDecrypter{ - OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit), + decrypter: sphinx.NewOnionErrorDecrypter( + circuit, hop.AttrErrorStruct, + ), } // Assert that the failure message can still be extracted. - failure, err := errorDecryptor.DecryptError(reason) + failure, err := errorDecryptor.DecryptError(reason, nil) require.NoError(t, err) incorrectDetails, ok := failure.msg.(*lnwire.FailIncorrectDetails) diff --git a/htlcswitch/hop/error_encryptor.go b/htlcswitch/hop/error_encryptor.go index 23272ec00d4..7fdb47ee0fa 100644 --- a/htlcswitch/hop/error_encryptor.go +++ b/htlcswitch/hop/error_encryptor.go @@ -2,12 +2,15 @@ package hop import ( "bytes" + "errors" "fmt" "io" + "time" "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" ) // EncrypterType establishes an enum used in serialization to indicate how to @@ -37,6 +40,24 @@ const ( // the same functionality as a EncrypterTypeSphinx, but is used to mark // our special-case error handling. EncrypterTypeRelaying = 4 + + // A set of tlv type definitions used to serialize the encrypter to the + // database. + // + // NOTE: A migration should be added whenever this list changes. This + // prevents against the database being rolled back to an older + // format where the surrounding logic might assume a different set of + // fields are known. + creationTimeType tlv.Type = 0 +) + +// AttrErrorStruct defines the message structure for an attributable error. Use +// a maximum route length of 20, a fixed payload length of 4 bytes to +// accommodate the a 32-bit hold time in milliseconds and use 4 byte hmacs. +// Total size including a 256 byte message from the error source works out to +// 1200 bytes. +var ( + AttrErrorStruct = sphinx.NewAttrErrorStructure(20, 4, 4) ) // IsBlinded returns a boolean indicating whether the error encrypter belongs @@ -45,9 +66,9 @@ func (e EncrypterType) IsBlinded() bool { return e == EncrypterTypeIntroduction || e == EncrypterTypeRelaying } -// ErrorEncrypterExtracter defines a function signature that extracts an -// ErrorEncrypter from an sphinx OnionPacket. -type ErrorEncrypterExtracter func(*btcec.PublicKey) (ErrorEncrypter, +// SharedSecretGenerator defines a function signature that extracts a shared +// secret from an sphinx OnionPacket. +type SharedSecretGenerator func(*btcec.PublicKey) (sphinx.Hash256, lnwire.FailCode) // ErrorEncrypter is an interface that is used to encrypt HTLC related errors @@ -58,19 +79,22 @@ type ErrorEncrypter interface { // encrypted opaque failure reason. This method will be used at the // source that the error occurs. It differs from IntermediateEncrypt // slightly, in that it computes a proper MAC over the error. - EncryptFirstHop(lnwire.FailureMessage) (lnwire.OpaqueReason, error) + EncryptFirstHop(lnwire.FailureMessage) (lnwire.OpaqueReason, + []byte, error) // EncryptMalformedError is similar to EncryptFirstHop (it adds the // MAC), but it accepts an opaque failure reason rather than a failure // message. This method is used when we receive an // UpdateFailMalformedHTLC from the remote peer and then need to // convert that into a proper error from only the raw bytes. - EncryptMalformedError(lnwire.OpaqueReason) lnwire.OpaqueReason + EncryptMalformedError(lnwire.OpaqueReason) (lnwire.OpaqueReason, []byte, + error) // IntermediateEncrypt wraps an already encrypted opaque reason error // in an additional layer of onion encryption. This process repeats // until the error arrives at the source of the payment. - IntermediateEncrypt(lnwire.OpaqueReason) lnwire.OpaqueReason + IntermediateEncrypt(lnwire.OpaqueReason, []byte) (lnwire.OpaqueReason, + []byte, error) // Type returns an enum indicating the underlying concrete instance // backing this interface. @@ -84,12 +108,13 @@ type ErrorEncrypter interface { // given io.Reader. Decode(io.Reader) error - // Reextract rederives the encrypter using the extracter, performing an - // ECDH with the sphinx router's key and the ephemeral public key. + // Reextract rederives the encrypter using the shared secret generator, + // performing an ECDH with the sphinx router's key and the ephemeral + // public key. // // NOTE: This should be called shortly after Decode to properly // reinitialize the error encrypter. - Reextract(ErrorEncrypterExtracter) error + Reextract(SharedSecretGenerator) error } // SphinxErrorEncrypter is a concrete implementation of both the ErrorEncrypter @@ -100,20 +125,63 @@ type SphinxErrorEncrypter struct { *sphinx.OnionErrorEncrypter EphemeralKey *btcec.PublicKey + CreatedAt time.Time } -// NewSphinxErrorEncrypter initializes a blank sphinx error encrypter, that -// should be used to deserialize an encoded SphinxErrorEncrypter. Since the -// actual encrypter is not stored in plaintext while at rest, reconstructing the -// error encrypter requires: +// NewSphinxErrorEncrypterUninitialized initializes a blank sphinx error +// encrypter, that should be used to deserialize an encoded +// SphinxErrorEncrypter. Since the actual encrypter is not stored in plaintext +// while at rest, reconstructing the error encrypter requires: // 1. Decode: to deserialize the ephemeral public key. // 2. Reextract: to "unlock" the actual error encrypter using an active // OnionProcessor. -func NewSphinxErrorEncrypter() *SphinxErrorEncrypter { +func NewSphinxErrorEncrypterUninitialized() *SphinxErrorEncrypter { return &SphinxErrorEncrypter{ - OnionErrorEncrypter: nil, - EphemeralKey: &btcec.PublicKey{}, + EphemeralKey: &btcec.PublicKey{}, + } +} + +// NewSphinxErrorEncrypter creates a new instance of a SphinxErrorEncrypter, +// initialized with the provided shared secret. To deserialize an encoded +// SphinxErrorEncrypter, use the NewSphinxErrorEncrypterUninitialized +// constructor. +func NewSphinxErrorEncrypter(ephemeralKey *btcec.PublicKey, + sharedSecret sphinx.Hash256) *SphinxErrorEncrypter { + + encrypter := &SphinxErrorEncrypter{ + EphemeralKey: ephemeralKey, } + + // Set creation time rounded to nanosecond to avoid differences after + // serialization. + encrypter.CreatedAt = time.Now().Truncate(time.Nanosecond) + + encrypter.initialize(sharedSecret) + + return encrypter +} + +// getHoldTime returns the hold time in decaseconds since the first +// instantiation of this sphinx error encrypter. +func (s *SphinxErrorEncrypter) getHoldTime() uint32 { + return uint32(time.Since(s.CreatedAt).Milliseconds() / 100) +} + +// encryptWithHoldTime derives the hold time from the encrypter's creation +// timestamp and passes it to the underlying EncryptError method. +func (s *SphinxErrorEncrypter) encryptWithHoldTime(initial bool, + data, attrData []byte) (lnwire.OpaqueReason, []byte, error) { + + holdTime := s.getHoldTime() + + return s.EncryptError(initial, data, attrData, holdTime) +} + +// initialize creates the underlying instance of the sphinx error encrypter. +func (s *SphinxErrorEncrypter) initialize(sharedSecret sphinx.Hash256) { + s.OnionErrorEncrypter = sphinx.NewOnionErrorEncrypter( + sharedSecret, AttrErrorStruct, + ) } // EncryptFirstHop transforms a concrete failure message into an encrypted @@ -123,16 +191,14 @@ func NewSphinxErrorEncrypter() *SphinxErrorEncrypter { // // NOTE: Part of the ErrorEncrypter interface. func (s *SphinxErrorEncrypter) EncryptFirstHop( - failure lnwire.FailureMessage) (lnwire.OpaqueReason, error) { + failure lnwire.FailureMessage) (lnwire.OpaqueReason, []byte, error) { var b bytes.Buffer if err := lnwire.EncodeFailure(&b, failure, 0); err != nil { - return nil, err + return nil, nil, err } - // We pass a true as the first parameter to indicate that a MAC should - // be added. - return s.EncryptError(true, b.Bytes()), nil + return s.encryptWithHoldTime(true, b.Bytes(), nil) } // EncryptMalformedError is similar to EncryptFirstHop (it adds the MAC), but @@ -143,9 +209,9 @@ func (s *SphinxErrorEncrypter) EncryptFirstHop( // // NOTE: Part of the ErrorEncrypter interface. func (s *SphinxErrorEncrypter) EncryptMalformedError( - reason lnwire.OpaqueReason) lnwire.OpaqueReason { + reason lnwire.OpaqueReason) (lnwire.OpaqueReason, []byte, error) { - return s.EncryptError(true, reason) + return s.encryptWithHoldTime(true, reason, nil) } // IntermediateEncrypt wraps an already encrypted opaque reason error in an @@ -156,9 +222,27 @@ func (s *SphinxErrorEncrypter) EncryptMalformedError( // // NOTE: Part of the ErrorEncrypter interface. func (s *SphinxErrorEncrypter) IntermediateEncrypt( - reason lnwire.OpaqueReason) lnwire.OpaqueReason { + reason lnwire.OpaqueReason, attrData []byte) (lnwire.OpaqueReason, + []byte, error) { + + encrypted, attrData, err := s.encryptWithHoldTime( + false, reason, attrData, + ) + + switch { + // If the structure of the error received from downstream is invalid, + // then generate a new attribution structure so that the sender is able + // to penalize the offending node. + case errors.Is(err, sphinx.ErrInvalidAttrStructure): + // Preserve the error message and initialize fresh attribution + // data. + return s.encryptWithHoldTime(true, reason, nil) + + case err != nil: + return lnwire.OpaqueReason{}, nil, err + } - return s.EncryptError(false, reason) + return encrypted, attrData, nil } // Type returns the identifier for a sphinx error encrypter. @@ -171,7 +255,20 @@ func (s *SphinxErrorEncrypter) Type() EncrypterType { func (s *SphinxErrorEncrypter) Encode(w io.Writer) error { ephemeral := s.EphemeralKey.SerializeCompressed() _, err := w.Write(ephemeral) - return err + if err != nil { + return err + } + + creationTime := uint64(s.CreatedAt.UnixNano()) + + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord(creationTimeType, &creationTime), + ) + if err != nil { + return err + } + + return tlvStream.Encode(w) } // Decode reconstructs the error encrypter's ephemeral public key from the @@ -188,16 +285,37 @@ func (s *SphinxErrorEncrypter) Decode(r io.Reader) error { return err } + // Try decode attributable error structure. + var creationTime uint64 + + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord(creationTimeType, &creationTime), + ) + if err != nil { + return err + } + + typeMap, err := tlvStream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + // Return early if this encrypter is not for attributable errors. + if len(typeMap) == 0 { + return nil + } + + // Set attributable error creation time. + s.CreatedAt = time.Unix(0, int64(creationTime)) + return nil } // Reextract rederives the error encrypter from the currently held EphemeralKey. // This intended to be used shortly after Decode, to fully initialize a // SphinxErrorEncrypter. -func (s *SphinxErrorEncrypter) Reextract( - extract ErrorEncrypterExtracter) error { - - obfuscator, failcode := extract(s.EphemeralKey) +func (s *SphinxErrorEncrypter) Reextract(extract SharedSecretGenerator) error { + sharedSecret, failcode := extract(s.EphemeralKey) if failcode != lnwire.CodeNone { // This should never happen, since we already validated that // this obfuscator can be extracted when it was received in the @@ -206,13 +324,7 @@ func (s *SphinxErrorEncrypter) Reextract( "obfuscator, got failcode: %d", failcode) } - sphinxEncrypter, ok := obfuscator.(*SphinxErrorEncrypter) - if !ok { - return fmt.Errorf("incorrect onion error extracter") - } - - // Copy the freshly extracted encrypter. - s.OnionErrorEncrypter = sphinxEncrypter.OnionErrorEncrypter + s.initialize(sharedSecret) return nil } @@ -235,9 +347,25 @@ type IntroductionErrorEncrypter struct { } // NewIntroductionErrorEncrypter returns a blank IntroductionErrorEncrypter. -func NewIntroductionErrorEncrypter() *IntroductionErrorEncrypter { +func NewIntroductionErrorEncrypter(ephemeralKey *btcec.PublicKey, + sharedSecret sphinx.Hash256) *IntroductionErrorEncrypter { + + return &IntroductionErrorEncrypter{ + ErrorEncrypter: NewSphinxErrorEncrypter( + ephemeralKey, sharedSecret, + ), + } +} + +// NewIntroductionErrorEncrypter returns a blank IntroductionErrorEncrypter. +// Since the actual encrypter is not stored in plaintext +// while at rest, reconstructing the error encrypter requires: +// 1. Decode: to deserialize the ephemeral public key. +// 2. Reextract: to "unlock" the actual error encrypter using an active +// OnionProcessor. +func NewIntroductionErrorEncrypterUninitialized() *IntroductionErrorEncrypter { return &IntroductionErrorEncrypter{ - ErrorEncrypter: NewSphinxErrorEncrypter(), + ErrorEncrypter: NewSphinxErrorEncrypterUninitialized(), } } @@ -249,7 +377,7 @@ func (i *IntroductionErrorEncrypter) Type() EncrypterType { // Reextract rederives the error encrypter from the currently held EphemeralKey, // relying on the logic in the underlying SphinxErrorEncrypter. func (i *IntroductionErrorEncrypter) Reextract( - extract ErrorEncrypterExtracter) error { + extract SharedSecretGenerator) error { return i.ErrorEncrypter.Reextract(extract) } @@ -266,9 +394,26 @@ type RelayingErrorEncrypter struct { // NewRelayingErrorEncrypter returns a blank RelayingErrorEncrypter with // an underlying SphinxErrorEncrypter. -func NewRelayingErrorEncrypter() *RelayingErrorEncrypter { +func NewRelayingErrorEncrypter(ephemeralKey *btcec.PublicKey, + sharedSecret sphinx.Hash256) *RelayingErrorEncrypter { + + return &RelayingErrorEncrypter{ + ErrorEncrypter: NewSphinxErrorEncrypter( + ephemeralKey, sharedSecret, + ), + } +} + +// NewRelayingErrorEncrypterUninitialized returns a blank RelayingErrorEncrypter +// with an underlying SphinxErrorEncrypter. +// Since the actual encrypter is not stored in plaintext +// while at rest, reconstructing the error encrypter requires: +// 1. Decode: to deserialize the ephemeral public key. +// 2. Reextract: to "unlock" the actual error encrypter using an active +// OnionProcessor. +func NewRelayingErrorEncrypterUninitialized() *RelayingErrorEncrypter { return &RelayingErrorEncrypter{ - ErrorEncrypter: NewSphinxErrorEncrypter(), + ErrorEncrypter: NewSphinxErrorEncrypterUninitialized(), } } @@ -280,7 +425,7 @@ func (r *RelayingErrorEncrypter) Type() EncrypterType { // Reextract rederives the error encrypter from the currently held EphemeralKey, // relying on the logic in the underlying SphinxErrorEncrypter. func (r *RelayingErrorEncrypter) Reextract( - extract ErrorEncrypterExtracter) error { + extract SharedSecretGenerator) error { return r.ErrorEncrypter.Reextract(extract) } diff --git a/htlcswitch/hop/error_encryptor_test.go b/htlcswitch/hop/error_encryptor_test.go new file mode 100644 index 00000000000..171bdd525d8 --- /dev/null +++ b/htlcswitch/hop/error_encryptor_test.go @@ -0,0 +1,188 @@ +package hop + +import ( + "bytes" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +// makeTestEncrypter creates a SphinxErrorEncrypter with a deterministic key +// and shared secret for testing. +func makeTestEncrypter(t *testing.T) (*SphinxErrorEncrypter, + *btcec.PublicKey, sphinx.Hash256) { + + t.Helper() + + privKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + + ephemeralKey := privKey.PubKey() + sharedSecret := sphinx.Hash256{1, 2, 3, 4, 5} + + enc := NewSphinxErrorEncrypter(ephemeralKey, sharedSecret) + + return enc, ephemeralKey, sharedSecret +} + +// TestSphinxErrorEncrypterEncodeDecode verifies that a SphinxErrorEncrypter +// round-trips through Encode/Decode, preserving the ephemeral key and +// CreatedAt timestamp. +func TestSphinxErrorEncrypterEncodeDecode(t *testing.T) { + t.Parallel() + + enc, ephemeralKey, _ := makeTestEncrypter(t) + + // Encode. + var buf bytes.Buffer + require.NoError(t, enc.Encode(&buf)) + + // Decode into a fresh encrypter. + dec := NewSphinxErrorEncrypterUninitialized() + require.NoError(t, dec.Decode(&buf)) + + // The ephemeral key should match. + require.True(t, ephemeralKey.IsEqual(dec.EphemeralKey), + "ephemeral keys don't match after round-trip") + + // The creation time should match. + require.True(t, enc.CreatedAt.Equal(dec.CreatedAt), + "creation times don't match: got %v, want %v", + dec.CreatedAt, enc.CreatedAt) +} + +// TestSphinxErrorEncrypterDecodeBackwardCompat verifies that Decode can handle +// data that was encoded without the TLV creation time (i.e. pre-attributable +// errors format). In that case, CreatedAt should remain zero. +func TestSphinxErrorEncrypterDecodeBackwardCompat(t *testing.T) { + t.Parallel() + + privKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + + // Encode just the compressed ephemeral key with no TLV suffix. + var buf bytes.Buffer + ephemeral := privKey.PubKey().SerializeCompressed() + _, err = buf.Write(ephemeral) + require.NoError(t, err) + + dec := NewSphinxErrorEncrypterUninitialized() + require.NoError(t, dec.Decode(&buf)) + + require.True(t, privKey.PubKey().IsEqual(dec.EphemeralKey)) + require.True(t, dec.CreatedAt.IsZero(), + "expected zero CreatedAt for legacy encoding") +} + +// TestGetHoldTime verifies the hold time computation. +func TestGetHoldTime(t *testing.T) { + t.Parallel() + + enc, _, _ := makeTestEncrypter(t) + + // Immediately after creation, hold time should be 0 (less than 100ms). + holdTime := enc.getHoldTime() + require.Zero(t, holdTime, + "hold time should be 0 immediately after creation") + + // Set creation time 500ms in the past. + enc.CreatedAt = time.Now().Add(-500 * time.Millisecond) + holdTime = enc.getHoldTime() + require.InDelta(t, 5, holdTime, 1, + "hold time should be ~5 for 500ms elapsed") + + // Set creation time 2 seconds in the past. + enc.CreatedAt = time.Now().Add(-2 * time.Second) + holdTime = enc.getHoldTime() + require.InDelta(t, 20, holdTime, 1, + "hold time should be ~20 for 2s elapsed") +} + +// TestSphinxErrorEncrypterReextract verifies that Reextract properly +// reinitializes the encrypter after Decode. +func TestSphinxErrorEncrypterReextract(t *testing.T) { + t.Parallel() + + enc, _, sharedSecret := makeTestEncrypter(t) + + // Encode. + var buf bytes.Buffer + require.NoError(t, enc.Encode(&buf)) + + // Decode. + dec := NewSphinxErrorEncrypterUninitialized() + require.NoError(t, dec.Decode(&buf)) + + // At this point the OnionErrorEncrypter is nil. + require.Nil(t, dec.OnionErrorEncrypter) + + // Reextract should re-initialize it. + err := dec.Reextract(func(key *btcec.PublicKey) (sphinx.Hash256, + lnwire.FailCode) { + + return sharedSecret, lnwire.CodeNone + }) + require.NoError(t, err) + require.NotNil(t, dec.OnionErrorEncrypter) +} + +// TestIntroductionErrorEncrypterEncodeDecode verifies round-trip for +// IntroductionErrorEncrypter. +func TestIntroductionErrorEncrypterEncodeDecode(t *testing.T) { + t.Parallel() + + privKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + + sharedSecret := sphinx.Hash256{10, 20, 30} + + enc := NewIntroductionErrorEncrypter(privKey.PubKey(), sharedSecret) + + var buf bytes.Buffer + require.NoError(t, enc.Encode(&buf)) + + dec := NewIntroductionErrorEncrypterUninitialized() + require.NoError(t, dec.Decode(&buf)) + + // Access the underlying SphinxErrorEncrypter via type assertion. + encInner, ok := enc.ErrorEncrypter.(*SphinxErrorEncrypter) + require.True(t, ok) + decInner, ok := dec.ErrorEncrypter.(*SphinxErrorEncrypter) + require.True(t, ok) + + require.True(t, privKey.PubKey().IsEqual(decInner.EphemeralKey)) + require.True(t, encInner.CreatedAt.Equal(decInner.CreatedAt)) + require.Equal(t, EncrypterType(EncrypterTypeIntroduction), dec.Type()) +} + +// TestRelayingErrorEncrypterEncodeDecode verifies round-trip for +// RelayingErrorEncrypter. +func TestRelayingErrorEncrypterEncodeDecode(t *testing.T) { + t.Parallel() + + privKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + + sharedSecret := sphinx.Hash256{10, 20, 30} + + enc := NewRelayingErrorEncrypter(privKey.PubKey(), sharedSecret) + + var buf bytes.Buffer + require.NoError(t, enc.Encode(&buf)) + + dec := NewRelayingErrorEncrypterUninitialized() + require.NoError(t, dec.Decode(&buf)) + + encInner, ok := enc.ErrorEncrypter.(*SphinxErrorEncrypter) + require.True(t, ok) + decInner, ok := dec.ErrorEncrypter.(*SphinxErrorEncrypter) + require.True(t, ok) + + require.True(t, privKey.PubKey().IsEqual(decInner.EphemeralKey)) + require.True(t, encInner.CreatedAt.Equal(decInner.CreatedAt)) + require.Equal(t, EncrypterType(EncrypterTypeRelaying), dec.Type()) +} diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index 553c4921dbc..06bea3f0012 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -102,10 +102,11 @@ type Iterator interface { // into the passed io.Writer. EncodeNextHop(w io.Writer) error - // ExtractErrorEncrypter returns the ErrorEncrypter needed for this hop, - // along with a failure code to signal if the decoding was successful. - ExtractErrorEncrypter(extractor ErrorEncrypterExtracter, - introductionNode bool) (ErrorEncrypter, lnwire.FailCode) + // ExtractEncrypterParams extracts the ephemeral key and shared secret + // from the onion packet and returns them to the caller along with a + // failure code to signal if the decoding was successful. + ExtractEncrypterParams(SharedSecretGenerator) (*btcec.PublicKey, + sphinx.Hash256, lnwire.BlindingPointRecord, lnwire.FailCode) } // sphinxHopIterator is the Sphinx implementation of hop iterator which uses @@ -482,38 +483,23 @@ func parseAndValidateSenderPayload(payloadBytes []byte, isFinalHop, return payload, routeRole, true, nil } -// ExtractErrorEncrypter decodes and returns the ErrorEncrypter for this hop, -// along with a failure code to signal if the decoding was successful. The -// ErrorEncrypter is used to encrypt errors back to the sender in the event that -// a payment fails. +// ExtractEncrypterParams extracts the ephemeral key, shared secret and blinding +// point record from the onion packet and returns them to the caller along with +// a failure code to signal if the decoding was successful. // // NOTE: Part of the HopIterator interface. -func (r *sphinxHopIterator) ExtractErrorEncrypter( - extracter ErrorEncrypterExtracter, introductionNode bool) ( - ErrorEncrypter, lnwire.FailCode) { +func (r *sphinxHopIterator) ExtractEncrypterParams( + extracter SharedSecretGenerator) (*btcec.PublicKey, sphinx.Hash256, + lnwire.BlindingPointRecord, lnwire.FailCode) { - encrypter, errCode := extracter(r.ogPacket.EphemeralKey) - if errCode != lnwire.CodeNone { - return nil, errCode + sharedSecret, failCode := extracter(r.ogPacket.EphemeralKey) + if failCode != lnwire.CodeNone { + return nil, sphinx.Hash256{}, r.blindingKit.UpdateAddBlinding, + failCode } - // If we're in a blinded path, wrap the error encrypter that we just - // derived in a "marker" type which we'll use to know what type of - // error we're handling. - switch { - case introductionNode: - return &IntroductionErrorEncrypter{ - ErrorEncrypter: encrypter, - }, errCode - - case r.blindingKit.UpdateAddBlinding.IsSome(): - return &RelayingErrorEncrypter{ - ErrorEncrypter: encrypter, - }, errCode - - default: - return encrypter, errCode - } + return r.ogPacket.EphemeralKey, sharedSecret, + r.blindingKit.UpdateAddBlinding, lnwire.CodeNone } // BlindingProcessor is an interface that provides the cryptographic operations @@ -901,33 +887,26 @@ func (p *OnionProcessor) DecodeHopIterators(id []byte, return resps, nil } -// ExtractErrorEncrypter takes an io.Reader which should contain the onion -// packet as original received by a forwarding node and creates an -// ErrorEncrypter instance using the derived shared secret. In the case that en -// error occurs, a lnwire failure code detailing the parsing failure will be -// returned. -func (p *OnionProcessor) ExtractErrorEncrypter(ephemeralKey *btcec.PublicKey) ( - ErrorEncrypter, lnwire.FailCode) { +// ExtractSharedSecret takes an ephemeral session key as original received by a +// forwarding node and generates the shared secret. In the case that an error +// occurs, a lnwire failure code detailing the parsing failure will be returned. +func (p *OnionProcessor) ExtractSharedSecret(ephemeralKey *btcec.PublicKey) ( + sphinx.Hash256, lnwire.FailCode) { - onionObfuscator, err := sphinx.NewOnionErrorEncrypter( - p.router, ephemeralKey, - ) + sharedSecret, err := p.router.GenerateSharedSecret(ephemeralKey, nil) if err != nil { switch err { case sphinx.ErrInvalidOnionVersion: - return nil, lnwire.CodeInvalidOnionVersion + return sphinx.Hash256{}, lnwire.CodeInvalidOnionVersion case sphinx.ErrInvalidOnionHMAC: - return nil, lnwire.CodeInvalidOnionHmac + return sphinx.Hash256{}, lnwire.CodeInvalidOnionHmac case sphinx.ErrInvalidOnionKey: - return nil, lnwire.CodeInvalidOnionKey + return sphinx.Hash256{}, lnwire.CodeInvalidOnionKey default: log.Errorf("unable to process onion packet: %v", err) - return nil, lnwire.CodeInvalidOnionKey + return sphinx.Hash256{}, lnwire.CodeInvalidOnionKey } } - return &SphinxErrorEncrypter{ - OnionErrorEncrypter: onionObfuscator, - EphemeralKey: ephemeralKey, - }, lnwire.CodeNone + return sharedSecret, lnwire.CodeNone } diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 3d0bd90ed45..5dd10af7d9d 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -738,7 +738,12 @@ func (f *interceptedForward) ResumeModified( // Fail notifies the intention to Fail an existing hold forward with an // encrypted failure reason. func (f *interceptedForward) Fail(reason []byte) error { - obfuscatedReason := f.packet.obfuscator.IntermediateEncrypt(reason) + obfuscatedReason, _, err := f.packet.obfuscator.IntermediateEncrypt( + reason, nil, + ) + if err != nil { + return err + } return f.resolve(&lnwire.UpdateFailHTLC{ Reason: obfuscatedReason, @@ -804,13 +809,19 @@ func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error { // Encrypt the failure for the first hop. This node will be the origin // of the failure. - reason, err := f.packet.obfuscator.EncryptFirstHop(failureMsg) + reason, attrData, err := f.packet.obfuscator.EncryptFirstHop(failureMsg) if err != nil { return fmt.Errorf("failed to encrypt failure reason %w", err) } + extraData, err := lnwire.AttrDataToExtraData(attrData) + if err != nil { + return err + } + return f.resolve(&lnwire.UpdateFailHTLC{ - Reason: reason, + Reason: reason, + ExtraData: extraData, }) } diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 1db005bf82f..da886995bc4 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -12,9 +12,11 @@ import ( "sync/atomic" "time" + "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog/v2" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/fn/v2" @@ -111,9 +113,15 @@ type ChannelLinkConfig struct { DecodeHopIterators func([]byte, []hop.DecodeHopIteratorRequest, bool) ( []hop.DecodeHopIteratorResponse, error) - // ExtractErrorEncrypter function is responsible for decoding HTLC - // Sphinx onion blob, and creating onion failure obfuscator. - ExtractErrorEncrypter hop.ErrorEncrypterExtracter + // ExtractSharedSecret function is responsible for decoding HTLC + // Sphinx onion blob, and deriving the shared secret. + ExtractSharedSecret hop.SharedSecretGenerator + + // CreateErrorEncrypter instantiates an error encrypter based on the + // provided encryption parameters. + CreateErrorEncrypter func(ephemeralKey *btcec.PublicKey, + sharedSecret sphinx.Hash256, isIntroduction, + hasBlindingPoint bool) hop.ErrorEncrypter // FetchLastChannelUpdate retrieves the latest routing policy for a // target channel. This channel will typically be the outgoing channel @@ -3057,19 +3065,11 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) { failedType = uint64(e.Type) } - // If we couldn't parse the payload, make our best - // effort at creating an error encrypter that knows - // what blinding type we were, but if we couldn't - // parse the payload we have no way of knowing whether - // we were the introduction node or not. - // - //nolint:ll - obfuscator, failCode := chanIterator.ExtractErrorEncrypter( - l.cfg.ExtractErrorEncrypter, - // We need our route role here because we - // couldn't parse or validate the payload. - routeRole == hop.RouteRoleIntroduction, - ) + // Let's extract the error encrypter parameters. + ephemeralKey, sharedSecret, blindingPoint, failCode := + chanIterator.ExtractEncrypterParams( + l.cfg.ExtractSharedSecret, + ) if failCode != lnwire.CodeNone { l.log.Errorf("could not extract error "+ "encrypter: %v", pldErr) @@ -3084,6 +3084,21 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) { continue } + // If we couldn't parse the payload, make our best + // effort at creating an error encrypter that knows + // what blinding type we were, but if we couldn't + // parse the payload we have no way of knowing whether + // we were the introduction node or not. Let's create + // the error encrypter based on the extracted encryption + // parameters. + obfuscator := l.cfg.CreateErrorEncrypter( + ephemeralKey, sharedSecret, + // We need our route role here because we + // couldn't parse or validate the payload. + routeRole == hop.RouteRoleIntroduction, + blindingPoint.IsSome(), + ) + // TODO: currently none of the test unit infrastructure // is setup to handle TLV payloads, so testing this // would require implementing a separate mock iterator @@ -3103,12 +3118,11 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) { continue } - // Retrieve onion obfuscator from onion blob in order to - // produce initial obfuscation of the onion failureCode. - obfuscator, failureCode := chanIterator.ExtractErrorEncrypter( - l.cfg.ExtractErrorEncrypter, - routeRole == hop.RouteRoleIntroduction, - ) + // Extract the encryption parameters. + ephemeralKey, sharedSecret, blindingPoint, failureCode := + chanIterator.ExtractEncrypterParams( + l.cfg.ExtractSharedSecret, + ) if failureCode != lnwire.CodeNone { // If we're unable to process the onion blob than we // should send the malformed htlc error to payment @@ -3124,6 +3138,14 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) { continue } + // Instantiate an error encrypter based on the extracted + // encryption parameters. + obfuscator := l.cfg.CreateErrorEncrypter( + ephemeralKey, sharedSecret, + routeRole == hop.RouteRoleIntroduction, + blindingPoint.IsSome(), + ) + fwdInfo := pld.ForwardingInfo() // Check whether the payload we've just processed uses our @@ -3562,13 +3584,20 @@ func (l *channelLink) sendHTLCError(add lnwire.UpdateAddHTLC, sourceRef channeldb.AddRef, failure *LinkError, e hop.ErrorEncrypter, isReceive bool) { - reason, err := e.EncryptFirstHop(failure.WireMessage()) + reason, attrData, err := e.EncryptFirstHop(failure.WireMessage()) if err != nil { l.log.Errorf("unable to obfuscate error: %v", err) return } - err = l.channel.FailHTLC(add.ID, reason, &sourceRef, nil, nil) + extraData, err := lnwire.AttrDataToExtraData(attrData) + if err != nil { + return + } + + err = l.channel.FailHTLC( + add.ID, reason, extraData, &sourceRef, nil, nil, + ) if err != nil { l.log.Errorf("unable cancel htlc: %v", err) return @@ -3577,7 +3606,7 @@ func (l *channelLink) sendHTLCError(add lnwire.UpdateAddHTLC, // Send the appropriate failure message depending on whether we're // in a blinded route or not. if err := l.sendIncomingHTLCFailureMsg( - add.ID, e, reason, + add.ID, e, reason, extraData, ); err != nil { l.log.Errorf("unable to send HTLC failure: %v", err) return @@ -3621,8 +3650,8 @@ func (l *channelLink) sendHTLCError(add lnwire.UpdateAddHTLC, // used if we are the introduction node and need to present an error as if // we're the failing party. func (l *channelLink) sendIncomingHTLCFailureMsg(htlcIndex uint64, - e hop.ErrorEncrypter, - originalFailure lnwire.OpaqueReason) error { + e hop.ErrorEncrypter, originalFailure lnwire.OpaqueReason, + extraData lnwire.ExtraOpaqueData) error { var msg lnwire.Message switch { @@ -3635,9 +3664,10 @@ func (l *channelLink) sendIncomingHTLCFailureMsg(htlcIndex uint64, // code. case e == nil: msg = &lnwire.UpdateFailHTLC{ - ChanID: l.ChanID(), - ID: htlcIndex, - Reason: originalFailure, + ChanID: l.ChanID(), + ID: htlcIndex, + Reason: originalFailure, + ExtraData: extraData, } l.log.Errorf("Unexpected blinded failure when "+ @@ -3648,9 +3678,10 @@ func (l *channelLink) sendIncomingHTLCFailureMsg(htlcIndex uint64, // transformation on the error message and can just send the original. case !e.Type().IsBlinded(): msg = &lnwire.UpdateFailHTLC{ - ChanID: l.ChanID(), - ID: htlcIndex, - Reason: originalFailure, + ChanID: l.ChanID(), + ID: htlcIndex, + Reason: originalFailure, + ExtraData: extraData, } // When we're the introduction node, we need to convert the error to @@ -3664,7 +3695,7 @@ func (l *channelLink) sendIncomingHTLCFailureMsg(htlcIndex uint64, failureMsg := lnwire.NewInvalidBlinding( fn.None[[lnwire.OnionPacketSize]byte](), ) - reason, err := e.EncryptFirstHop(failureMsg) + reason, _, err := e.EncryptFirstHop(failureMsg) if err != nil { return err } @@ -4185,7 +4216,7 @@ func (l *channelLink) processRemoteUpdateFailMalformedHTLC( // If remote side have been unable to parse the onion blob we have sent // to it, than we should transform the malformed HTLC message to the // usual HTLC fail message. - err := l.channel.ReceiveFailHTLC(msg.ID, b.Bytes()) + err := l.channel.ReceiveFailHTLC(msg.ID, b.Bytes(), msg.ExtraData) if err != nil { l.failf(LinkFailureError{code: ErrInvalidUpdate}, "unable to handle upstream fail HTLC: %v", err) @@ -4226,7 +4257,7 @@ func (l *channelLink) processRemoteUpdateFailHTLC( // Add fail to the update log. idx := msg.ID - err := l.channel.ReceiveFailHTLC(idx, msg.Reason[:]) + err := l.channel.ReceiveFailHTLC(idx, msg.Reason[:], msg.ExtraData) if err != nil { l.failf(LinkFailureError{code: ErrInvalidUpdate}, "unable to handle upstream fail HTLC: %v", err) @@ -4666,8 +4697,8 @@ func (l *channelLink) processLocalUpdateFailHTLC(ctx context.Context, // remove then HTLC from our local state machine. inKey := pkt.inKey() err := l.channel.FailHTLC( - pkt.incomingHTLCID, htlc.Reason, pkt.sourceRef, pkt.destRef, - &inKey, + pkt.incomingHTLCID, htlc.Reason, htlc.ExtraData, pkt.sourceRef, + pkt.destRef, &inKey, ) if err != nil { l.log.Errorf("unable to cancel incoming HTLC for "+ @@ -4703,7 +4734,9 @@ func (l *channelLink) processLocalUpdateFailHTLC(ctx context.Context, // HTLC. If the incoming blinding point is non-nil, we know that we are // a relaying node in a blinded path. Otherwise, we're either an // introduction node or not part of a blinded path at all. - err = l.sendIncomingHTLCFailureMsg(htlc.ID, pkt.obfuscator, htlc.Reason) + err = l.sendIncomingHTLCFailureMsg( + htlc.ID, pkt.obfuscator, htlc.Reason, htlc.ExtraData, + ) if err != nil { l.log.Errorf("unable to send HTLC failure: %v", err) diff --git a/htlcswitch/link_isolated_test.go b/htlcswitch/link_isolated_test.go index 9e74c487580..e2a43c6633d 100644 --- a/htlcswitch/link_isolated_test.go +++ b/htlcswitch/link_isolated_test.go @@ -254,7 +254,7 @@ func (l *linkTestContext) receiveFailAliceToBob() { l.t.Fatalf("expected UpdateFailHTLC, got %T", msg) } - err := l.bobChannel.ReceiveFailHTLC(failMsg.ID, failMsg.Reason) + err := l.bobChannel.ReceiveFailHTLC(failMsg.ID, failMsg.Reason, nil) if err != nil { l.t.Fatalf("unable to apply received fail htlc: %v", err) } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 29b4f902d0b..d22c60a6922 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1802,9 +1802,10 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) { t.Cleanup(n.stop) // Replace decode function with another which throws an error. - n.carolChannelLink.cfg.ExtractErrorEncrypter = func( - *btcec.PublicKey) (hop.ErrorEncrypter, lnwire.FailCode) { - return nil, lnwire.CodeInvalidOnionVersion + n.carolChannelLink.cfg.ExtractSharedSecret = func( + *btcec.PublicKey) (sphinx.Hash256, lnwire.FailCode) { + + return sphinx.Hash256{}, lnwire.CodeInvalidOnionVersion } carolBandwidthBefore := n.carolChannelLink.Bandwidth() @@ -2213,9 +2214,15 @@ func newSingleLinkTestHarness(t *testing.T, chanAmt, Circuits: aliceSwitch.CircuitModifier(), ForwardPackets: forwardPackets, DecodeHopIterators: decoder.DecodeHopIterators, - ExtractErrorEncrypter: func(*btcec.PublicKey) ( - hop.ErrorEncrypter, lnwire.FailCode) { - return obfuscator, lnwire.CodeNone + ExtractSharedSecret: func(*btcec.PublicKey) ( + sphinx.Hash256, lnwire.FailCode) { + + return sphinx.Hash256{}, lnwire.CodeNone + }, + CreateErrorEncrypter: func(*btcec.PublicKey, + sphinx.Hash256, bool, bool) hop.ErrorEncrypter { + + return obfuscator }, FetchLastChannelUpdate: mockGetChanUpdateMessage, PreimageCache: pCache, @@ -2672,7 +2679,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { reason := make([]byte, 292) copy(reason, []byte("nop")) - err = harness.bobChannel.FailHTLC(bobIndex, reason, nil, nil, nil) + err = harness.bobChannel.FailHTLC(bobIndex, reason, nil, nil, nil, nil) require.NoError(t, err, "unable to fail htlc") failMsg := &lnwire.UpdateFailHTLC{ ID: 1, @@ -2919,7 +2926,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { if !ok { t.Fatalf("expected UpdateFailHTLC, got %T", msg) } - err = harness.bobChannel.ReceiveFailHTLC(failMsg.ID, []byte("fail")) + err = harness.bobChannel.ReceiveFailHTLC( + failMsg.ID, []byte("fail"), nil, + ) require.NoError(t, err, "failed receiving fail htlc") // After failing an HTLC, the link will automatically trigger @@ -4898,10 +4907,15 @@ func (h *persistentLinkHarness) restartLink( Circuits: h.hSwitch.CircuitModifier(), ForwardPackets: forwardPackets, DecodeHopIterators: decoder.DecodeHopIterators, - ExtractErrorEncrypter: func(*btcec.PublicKey) ( - hop.ErrorEncrypter, lnwire.FailCode) { + ExtractSharedSecret: func(*btcec.PublicKey) ( + sphinx.Hash256, lnwire.FailCode) { + + return sphinx.Hash256{}, lnwire.CodeNone + }, + CreateErrorEncrypter: func(*btcec.PublicKey, + sphinx.Hash256, bool, bool) hop.ErrorEncrypter { - return obfuscator, lnwire.CodeNone + return obfuscator }, FetchLastChannelUpdate: mockGetChanUpdateMessage, PreimageCache: pCache, @@ -7300,7 +7314,7 @@ func TestChannelLinkShortFailureRelay(t *testing.T) { // Return a short htlc failure from Bob to Alice and lock in. shortReason := make([]byte, 260) - err = harness.bobChannel.FailHTLC(0, shortReason, nil, nil, nil) + err = harness.bobChannel.FailHTLC(0, shortReason, nil, nil, nil, nil) require.NoError(t, err) harness.aliceLink.HandleChannelUpdate(&lnwire.UpdateFailHTLC{ diff --git a/htlcswitch/mailbox.go b/htlcswitch/mailbox.go index b283825dd96..8af7ddfb330 100644 --- a/htlcswitch/mailbox.go +++ b/htlcswitch/mailbox.go @@ -697,6 +697,7 @@ func (m *memoryMailBox) FailAdd(pkt *htlcPacket) { var ( localFailure = false reason lnwire.OpaqueReason + attrData []byte ) // Create a temporary channel failure which we will send back to our @@ -721,13 +722,19 @@ func (m *memoryMailBox) FailAdd(pkt *htlcPacket) { // If the packet is part of a forward, (identified by a non-nil // obfuscator) we need to encrypt the error back to the source. var err error - reason, err = pkt.obfuscator.EncryptFirstHop(failure) + reason, attrData, err = pkt.obfuscator.EncryptFirstHop(failure) if err != nil { log.Errorf("Unable to obfuscate error: %v", err) return } } + extraData, err := lnwire.AttrDataToExtraData(attrData) + if err != nil { + log.Errorf("Failed to convert attr data: %v", err) + return + } + // Create a link error containing the temporary channel failure and a // detail which indicates the we failed to add the htlc. linkError := NewDetailedLinkError( @@ -744,7 +751,8 @@ func (m *memoryMailBox) FailAdd(pkt *htlcPacket) { obfuscator: pkt.obfuscator, linkFailure: linkError, htlc: &lnwire.UpdateFailHTLC{ - Reason: reason, + Reason: reason, + ExtraData: extraData, }, } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 70bd73c37d2..c6b6150d1ab 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -341,11 +341,18 @@ func (r *mockHopIterator) ExtraOnionBlob() []byte { return nil } -func (r *mockHopIterator) ExtractErrorEncrypter( - extracter hop.ErrorEncrypterExtracter, _ bool) (hop.ErrorEncrypter, - lnwire.FailCode) { +func (r *mockHopIterator) ExtractEncrypterParams( + extracter hop.SharedSecretGenerator) (*btcec.PublicKey, sphinx.Hash256, + lnwire.BlindingPointRecord, lnwire.FailCode) { + + sharedSecret, failCode := extracter(nil) + if failCode != lnwire.CodeNone { + return nil, sphinx.Hash256{}, lnwire.BlindingPointRecord{}, + failCode + } - return extracter(nil) + return &btcec.PublicKey{}, sharedSecret, lnwire.BlindingPointRecord{}, + lnwire.CodeNone } func (r *mockHopIterator) EncodeNextHop(w io.Writer) error { @@ -412,16 +419,14 @@ func (o *mockObfuscator) Decode(r io.Reader) error { return nil } -func (o *mockObfuscator) Reextract( - extracter hop.ErrorEncrypterExtracter) error { - +func (o *mockObfuscator) Reextract(extracter hop.SharedSecretGenerator) error { return nil } var fakeHmac = []byte("hmachmachmachmachmachmachmachmac") func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) ( - lnwire.OpaqueReason, error) { + lnwire.OpaqueReason, []byte, error) { o.failure = failure @@ -429,22 +434,27 @@ func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) ( b.Write(fakeHmac) if err := lnwire.EncodeFailure(&b, failure, 0); err != nil { - return nil, err + return nil, nil, err } - return b.Bytes(), nil + + return b.Bytes(), nil, nil } -func (o *mockObfuscator) IntermediateEncrypt(reason lnwire.OpaqueReason) lnwire.OpaqueReason { - return reason +func (o *mockObfuscator) IntermediateEncrypt(reason lnwire.OpaqueReason, + attrData []byte) (lnwire.OpaqueReason, []byte, error) { + + return reason, nil, nil } -func (o *mockObfuscator) EncryptMalformedError(reason lnwire.OpaqueReason) lnwire.OpaqueReason { +func (o *mockObfuscator) EncryptMalformedError( + reason lnwire.OpaqueReason) (lnwire.OpaqueReason, []byte, error) { + var b bytes.Buffer b.Write(fakeHmac) b.Write(reason) - return b.Bytes() + return b.Bytes(), nil, nil } // mockDeobfuscator mock implementation of the failure deobfuscator which @@ -455,8 +465,8 @@ func newMockDeobfuscator() ErrorDecrypter { return &mockDeobfuscator{} } -func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason) ( - *ForwardingError, error) { +func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason, + attrData []byte) (*ForwardingError, error) { if !bytes.Equal(reason[:32], fakeHmac) { return nil, errors.New("fake decryption error") @@ -469,7 +479,7 @@ func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason) ( return nil, err } - return NewForwardingError(failure, 1), nil + return NewForwardingError(failure, 1, nil), nil } var _ ErrorDecrypter = (*mockDeobfuscator)(nil) @@ -1133,21 +1143,6 @@ func (m *mockCircuitMap) NumOpen() int { return 0 } -type mockOnionErrorDecryptor struct { - sourceIdx int - message []byte - err error -} - -func (m *mockOnionErrorDecryptor) DecryptError(encryptedData []byte) ( - *sphinx.DecryptedError, error) { - - return &sphinx.DecryptedError{ - SenderIdx: m.sourceIdx, - Message: m.message, - }, m.err -} - var _ htlcNotifier = (*mockHTLCNotifier)(nil) type mockHTLCNotifier struct { diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index a3aae809b93..5de54ad986b 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -166,10 +166,10 @@ type Config struct { // forwarding packages, and ack settles and fails contained within them. SwitchPackager channeldb.FwdOperator - // ExtractErrorEncrypter is an interface allowing switch to reextract + // ExtractSharedSecret is an interface allowing switch to reextract // error encrypters stored in the circuit map on restarts, since they // are not stored directly within the database. - ExtractErrorEncrypter hop.ErrorEncrypterExtracter + ExtractSharedSecret hop.SharedSecretGenerator // FetchLastChannelUpdate retrieves the latest routing policy for a // target channel. This channel will typically be the outgoing channel @@ -361,11 +361,11 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { resStore := newResolutionStore(cfg.DB) circuitMap, err := NewCircuitMap(&CircuitMapConfig{ - DB: cfg.DB, - FetchAllOpenChannels: cfg.FetchAllOpenChannels, - FetchClosedChannels: cfg.FetchClosedChannels, - ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, - CheckResolutionMsg: resStore.checkResolutionMsg, + DB: cfg.DB, + FetchAllOpenChannels: cfg.FetchAllOpenChannels, + FetchClosedChannels: cfg.FetchClosedChannels, + ExtractSharedSecret: cfg.ExtractSharedSecret, + CheckResolutionMsg: resStore.checkResolutionMsg, }) if err != nil { return nil, err @@ -1105,9 +1105,14 @@ func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter, // A regular multi-hop payment error that we'll need to // decrypt. default: + attrData, err := lnwire.ExtraDataToAttrData(htlc.ExtraData) + if err != nil { + return err + } + // We'll attempt to fully decrypt the onion encrypted // error. If we're unable to then we'll bail early. - failure, err := deobfuscator.DecryptError(htlc.Reason) + failure, err := deobfuscator.DecryptError(htlc.Reason, attrData) if err != nil { log.Errorf("unable to de-obfuscate onion failure "+ "(hash=%v, pid=%d): %v", @@ -1232,7 +1237,9 @@ func (s *Switch) failAddPacket(packet *htlcPacket, failure *LinkError) error { // Encrypt the failure so that the sender will be able to read the error // message. Since we failed this packet, we use EncryptFirstHop to // obfuscate the failure for their eyes only. - reason, err := packet.obfuscator.EncryptFirstHop(failure.WireMessage()) + reason, attrData, err := packet.obfuscator.EncryptFirstHop( + failure.WireMessage(), + ) if err != nil { err := fmt.Errorf("unable to obfuscate "+ "error: %v", err) @@ -1242,6 +1249,11 @@ func (s *Switch) failAddPacket(packet *htlcPacket, failure *LinkError) error { log.Error(failure.Error()) + extraData, err := lnwire.AttrDataToExtraData(attrData) + if err != nil { + return err + } + // Create a failure packet for this htlc. The full set of // information about the htlc failure is included so that they can // be included in link failure notifications. @@ -1259,7 +1271,8 @@ func (s *Switch) failAddPacket(packet *htlcPacket, failure *LinkError) error { obfuscator: packet.obfuscator, linkFailure: failure, htlc: &lnwire.UpdateFailHTLC{ - Reason: reason, + Reason: reason, + ExtraData: extraData, }, } @@ -3163,7 +3176,7 @@ func (s *Switch) handlePacketFail(packet *htlcPacket, var err error // TODO(roasbeef): don't need to pass actually? failure := &lnwire.FailPermanentChannelFailure{} - htlc.Reason, err = circuit.ErrorEncrypter.EncryptFirstHop( + reason, attrData, err := circuit.ErrorEncrypter.EncryptFirstHop( failure, ) if err != nil { @@ -3171,6 +3184,12 @@ func (s *Switch) handlePacketFail(packet *htlcPacket, log.Error(err) } + htlc.Reason = reason + htlc.ExtraData, err = lnwire.AttrDataToExtraData(attrData) + if err != nil { + return err + } + // Alternatively, if the remote party sends us an // UpdateFailMalformedHTLC, then we'll need to convert this into a // proper well formatted onion error as there's no HMAC currently. @@ -3181,16 +3200,41 @@ func (s *Switch) handlePacketFail(packet *htlcPacket, packet.incomingChanID, packet.incomingHTLCID, packet.outgoingChanID, packet.outgoingHTLCID) - htlc.Reason = circuit.ErrorEncrypter.EncryptMalformedError( - htlc.Reason, - ) + reason, attrData, err := + circuit.ErrorEncrypter.EncryptMalformedError( + htlc.Reason, + ) + if err != nil { + return err + } + + htlc.Reason = reason + htlc.ExtraData, err = lnwire.AttrDataToExtraData(attrData) + if err != nil { + return err + } default: + attrData, err := lnwire.ExtraDataToAttrData(htlc.ExtraData) + if err != nil { + return err + } + // Otherwise, it's a forwarded error, so we'll perform a // wrapper encryption as normal. - htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt( - htlc.Reason, - ) + reason, attrData, err := + circuit.ErrorEncrypter.IntermediateEncrypt( + htlc.Reason, attrData, + ) + if err != nil { + return err + } + + htlc.Reason = reason + htlc.ExtraData, err = lnwire.AttrDataToExtraData(attrData) + if err != nil { + return err + } } // Deliver this packet. diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index e8176aaeb59..541e16f4212 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/davecgh/go-spew/spew" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" @@ -2743,7 +2744,7 @@ func TestSwitchSendPayment(t *testing.T) { // back. This request should be forwarded back to alice channel link. obfuscator := NewMockObfuscator() failure := lnwire.NewFailIncorrectDetails(update.Amount, 100) - reason, err := obfuscator.EncryptFirstHop(failure) + reason, _, err := obfuscator.EncryptFirstHop(failure) require.NoError(t, err, "unable obfuscate failure") if s.IsForwardedHTLC(aliceChannelLink.ShortChanID(), update.ID) { @@ -3234,9 +3235,9 @@ func TestInvalidFailure(t *testing.T) { // Get payment result from switch. We expect an unreadable failure // message error. deobfuscator := SphinxErrorDecrypter{ - OnionErrorDecrypter: &mockOnionErrorDecryptor{ - err: ErrUnreadableFailureMessage, - }, + decrypter: sphinx.NewOnionErrorDecrypter( + nil, hop.AttrErrorStruct, + ), } resultChan, err := s.GetAttemptResult( @@ -3255,43 +3256,6 @@ func TestInvalidFailure(t *testing.T) { case <-time.After(time.Second): t.Fatal("err wasn't received") } - - // Modify the decryption to simulate that decryption went alright, but - // the failure cannot be decoded. - deobfuscator = SphinxErrorDecrypter{ - OnionErrorDecrypter: &mockOnionErrorDecryptor{ - sourceIdx: 2, - message: []byte{200}, - }, - } - - resultChan, err = s.GetAttemptResult( - paymentID, rhash, &deobfuscator, - ) - if err != nil { - t.Fatal(err) - } - - select { - case result := <-resultChan: - rtErr, ok := result.Error.(ClearTextError) - if !ok { - t.Fatal("expected ClearTextError") - } - source, ok := rtErr.(*ForwardingError) - if !ok { - t.Fatalf("expected forwarding error, got: %T", rtErr) - } - if source.FailureSourceIdx != 2 { - t.Fatal("unexpected error source index") - } - if rtErr.WireMessage() != nil { - t.Fatal("expected empty failure message") - } - - case <-time.After(time.Second): - t.Fatal("err wasn't received") - } } // htlcNotifierEvents is a function that generates a set of expected htlc @@ -4069,7 +4033,9 @@ func TestSwitchHoldForward(t *testing.T) { OnionSHA256: shaOnionBlob, } - fwdErr, err := newMockDeobfuscator().DecryptError(failPacket.Reason) + fwdErr, err := newMockDeobfuscator().DecryptError( + failPacket.Reason, nil, + ) require.NoError(t, err) require.Equal(t, expectedFailure, fwdErr.WireMessage()) @@ -5535,7 +5501,7 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { require.True(t, ok) fwdErr, err := newMockDeobfuscator().DecryptError( - failHtlc.Reason, + failHtlc.Reason, nil, ) require.NoError(t, err) diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 2e084250943..d781b0e3421 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -1146,9 +1146,15 @@ func (h *hopNetwork) createChannelLink(server, peer *mockServer, Circuits: server.htlcSwitch.CircuitModifier(), ForwardPackets: forwardPackets, DecodeHopIterators: decoder.DecodeHopIterators, - ExtractErrorEncrypter: func(*btcec.PublicKey) ( - hop.ErrorEncrypter, lnwire.FailCode) { - return h.obfuscator, lnwire.CodeNone + ExtractSharedSecret: func(*btcec.PublicKey) ( + sphinx.Hash256, lnwire.FailCode) { + + return sphinx.Hash256{}, lnwire.CodeNone + }, + CreateErrorEncrypter: func(*btcec.PublicKey, + sphinx.Hash256, bool, bool) hop.ErrorEncrypter { + + return h.obfuscator }, FetchLastChannelUpdate: mockGetChanUpdateMessage, Registry: server.registry, diff --git a/itest/list_on_test.go b/itest/list_on_test.go index a49efa477b3..51003ff805e 100644 --- a/itest/list_on_test.go +++ b/itest/list_on_test.go @@ -715,6 +715,10 @@ var allTestCases = []*lntest.TestCase{ Name: "experimental accountability", TestFunc: testExperimentalAccountability, }, + { + Name: "attributable failure hold times", + TestFunc: testAttributableFailureHoldTimes, + }, { Name: "quiescence", TestFunc: testQuiescence, diff --git a/itest/lnd_attributable_failure_test.go b/itest/lnd_attributable_failure_test.go new file mode 100644 index 00000000000..edb01477eb0 --- /dev/null +++ b/itest/lnd_attributable_failure_test.go @@ -0,0 +1,103 @@ +package itest + +import ( + "github.com/lightningnetwork/lnd/funding" + "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" + "github.com/lightningnetwork/lnd/lntest" + "github.com/stretchr/testify/require" +) + +// testAttributableFailureHoldTimes verifies that when a payment fails at a +// downstream node, the sender receives hold times in the failure response via +// attributable errors. It sets up Alice -> Bob -> Carol, triggers a failure at +// Carol (unknown payment hash), and checks that Alice's payment result includes +// hold times that correctly correspond to the route hops. +func testAttributableFailureHoldTimes(ht *lntest.HarnessTest) { + const chanAmt = funding.MaxBtcFundingAmount + + alice := ht.NewNodeWithCoins("Alice", nil) + bob := ht.NewNodeWithCoins("Bob", nil) + carol := ht.NewNode("Carol", nil) + + ht.EnsureConnected(alice, bob) + ht.ConnectNodes(bob, carol) + + // Open channels: Alice -> Bob -> Carol. + chanPointAB := ht.OpenChannel( + alice, bob, + lntest.OpenChannelParams{Amt: chanAmt}, + ) + chanPointBC := ht.OpenChannel( + bob, carol, + lntest.OpenChannelParams{Amt: chanAmt}, + ) + + // Wait for Alice to see both channels. + ht.AssertChannelInGraph(alice, chanPointAB) + ht.AssertChannelInGraph(alice, chanPointBC) + + // Create an invoice from Carol to get valid route parameters. + carolInvoice := carol.RPC.AddInvoice(&lnrpc.Invoice{ + Memo: "hold-time-test", + Value: 10_000, + }) + carolPayReq := carol.RPC.DecodePayReq(carolInvoice.PaymentRequest) + + // Send a payment with a random (wrong) payment hash so Carol rejects + // it with INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS. This ensures the error + // originates at Carol (the final hop) and propagates back through Bob + // to Alice, with each hop adding its hold time. + sendReq := &routerrpc.SendPaymentRequest{ + PaymentHash: ht.Random32Bytes(), + Dest: carol.PubKey[:], + Amt: 10_000, + FinalCltvDelta: int32(carolPayReq.CltvExpiry), + FeeLimitMsat: noFeeLimitMsat, + MaxParts: 1, + } + + failReason := lnrpc.PaymentFailureReason_FAILURE_REASON_INCORRECT_PAYMENT_DETAILS //nolint:ll + payment := ht.SendPaymentAssertFail( + alice, sendReq, failReason, + ) + + // Verify we got at least one HTLC attempt with failure info. + require.NotEmpty(ht, payment.Htlcs, "expected at least one HTLC") + htlcAttempt := payment.Htlcs[len(payment.Htlcs)-1] + require.NotNil(ht, htlcAttempt.Failure, "expected failure info") + + // The route should have 2 hops: Alice->Bob->Carol. + require.Len(ht, htlcAttempt.Route.Hops, 2, + "expected 2-hop route (Bob, Carol)") + + // Verify the failure code and source. Carol is at index 2 (0=Alice, + // 1=Bob, 2=Carol). + require.Equal(ht, + lnrpc.Failure_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, + htlcAttempt.Failure.Code, + ) + require.EqualValues(ht, 2, htlcAttempt.Failure.FailureSourceIndex, + "failure should originate from Carol (index 2)") + + // Verify hold times. With attributable errors, we expect one hold time + // entry per hop in the route. hold_times[0] corresponds to + // route.hops[0] (Bob) and hold_times[1] to route.hops[1] (Carol). + holdTimes := htlcAttempt.Failure.HoldTimes + require.Len(ht, holdTimes, len(htlcAttempt.Route.Hops), + "hold_times should have one entry per route hop") + + // Hold times are in 100ms units. In a test environment with local + // nodes, processing should be nearly instant (likely 0-2 units). + // We verify each value is within a reasonable upper bound (< 10s) + // to catch any corruption or mis-encoding. + const maxReasonableHoldTime = uint32(100) // 10 seconds + for i, holdTime := range holdTimes { + require.LessOrEqual(ht, holdTime, maxReasonableHoldTime, + "hold time for hop %d (%s) unreasonably large: "+ + "%d (= %dms)", + i, htlcAttempt.Route.Hops[i].PubKey, + holdTime, holdTime*100, + ) + } +} diff --git a/lnrpc/lightning.pb.go b/lnrpc/lightning.pb.go index 5f97bbd2d42..fd5573d52c9 100644 --- a/lnrpc/lightning.pb.go +++ b/lnrpc/lightning.pb.go @@ -16978,7 +16978,13 @@ type Failure struct { // the failure message. Position zero is the sender node. FailureSourceIndex uint32 `protobuf:"varint,8,opt,name=failure_source_index,json=failureSourceIndex,proto3" json:"failure_source_index,omitempty"` // A failure type-dependent block height. - Height uint32 `protobuf:"varint,9,opt,name=height,proto3" json:"height,omitempty"` + Height uint32 `protobuf:"varint,9,opt,name=height,proto3" json:"height,omitempty"` + // An array of hold times (in 100ms units) as reported by the nodes along + // the route via attributable errors. The first element corresponds to the + // first hop after the sender, with greater indices indicating nodes + // further along the route. Multiply by 100 to get milliseconds. This + // field is only populated when the error includes attribution data. + HoldTimes []uint32 `protobuf:"varint,10,rep,packed,name=hold_times,json=holdTimes,proto3" json:"hold_times,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -17069,6 +17075,13 @@ func (x *Failure) GetHeight() uint32 { return 0 } +func (x *Failure) GetHoldTimes() []uint32 { + if x != nil { + return x.HoldTimes + } + return nil +} + type ChannelUpdate struct { state protoimpl.MessageState `protogen:"open.v1"` // The signature that validates the announced data and proves the ownership @@ -20115,7 +20128,7 @@ const file_lightning_proto_rawDesc = "" + "\x12method_permissions\x18\x01 \x03(\v25.lnrpc.ListPermissionsResponse.MethodPermissionsEntryR\x11methodPermissions\x1ac\n" + "\x16MethodPermissionsEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x123\n" + - "\x05value\x18\x02 \x01(\v2\x1d.lnrpc.MacaroonPermissionListR\x05value:\x028\x01\"\xcc\b\n" + + "\x05value\x18\x02 \x01(\v2\x1d.lnrpc.MacaroonPermissionListR\x05value:\x028\x01\"\xeb\b\n" + "\aFailure\x12.\n" + "\x04code\x18\x01 \x01(\x0e2\x1a.lnrpc.Failure.FailureCodeR\x04code\x12;\n" + "\x0echannel_update\x18\x03 \x01(\v2\x14.lnrpc.ChannelUpdateR\rchannelUpdate\x12\x1b\n" + @@ -20125,7 +20138,10 @@ const file_lightning_proto_rawDesc = "" + "cltvExpiry\x12\x14\n" + "\x05flags\x18\a \x01(\rR\x05flags\x120\n" + "\x14failure_source_index\x18\b \x01(\rR\x12failureSourceIndex\x12\x16\n" + - "\x06height\x18\t \x01(\rR\x06height\"\x8b\x06\n" + + "\x06height\x18\t \x01(\rR\x06height\x12\x1d\n" + + "\n" + + "hold_times\x18\n" + + " \x03(\rR\tholdTimes\"\x8b\x06\n" + "\vFailureCode\x12\f\n" + "\bRESERVED\x10\x00\x12(\n" + "$INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS\x10\x01\x12\x1c\n" + diff --git a/lnrpc/lightning.proto b/lnrpc/lightning.proto index 66dedd431e8..cbcf19fb1dc 100644 --- a/lnrpc/lightning.proto +++ b/lnrpc/lightning.proto @@ -5142,6 +5142,15 @@ message Failure { // A failure type-dependent block height. uint32 height = 9; + + /* + An array of hold times (in 100ms units) as reported by the nodes along + the route via attributable errors. The first element corresponds to the + first hop after the sender, with greater indices indicating nodes + further along the route. Multiply by 100 to get milliseconds. This + field is only populated when the error includes attribution data. + */ + repeated uint32 hold_times = 10; } message ChannelUpdate { diff --git a/lnrpc/lightning.swagger.json b/lnrpc/lightning.swagger.json index bc16701b04a..3a1897fbbf9 100644 --- a/lnrpc/lightning.swagger.json +++ b/lnrpc/lightning.swagger.json @@ -5127,6 +5127,14 @@ "type": "integer", "format": "int64", "description": "A failure type-dependent block height." + }, + "hold_times": { + "type": "array", + "items": { + "type": "integer", + "format": "int64" + }, + "description": "An array of hold times (in 100ms units) as reported by the nodes along\nthe route via attributable errors. The first element corresponds to the\nfirst hop after the sender, with greater indices indicating nodes\nfurther along the route. Multiply by 100 to get milliseconds. This\nfield is only populated when the error includes attribution data." } } }, diff --git a/lnrpc/routerrpc/router.swagger.json b/lnrpc/routerrpc/router.swagger.json index 59aa426dd1b..1295c45c6cc 100644 --- a/lnrpc/routerrpc/router.swagger.json +++ b/lnrpc/routerrpc/router.swagger.json @@ -860,6 +860,14 @@ "type": "integer", "format": "int64", "description": "A failure type-dependent block height." + }, + "hold_times": { + "type": "array", + "items": { + "type": "integer", + "format": "int64" + }, + "description": "An array of hold times (in 100ms units) as reported by the nodes along\nthe route via attributable errors. The first element corresponds to the\nfirst hop after the sender, with greater indices indicating nodes\nfurther along the route. Multiply by 100 to get milliseconds. This\nfield is only populated when the error includes attribution data." } } }, diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 3085f587dc1..5bbd5e16a7b 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -1546,6 +1546,7 @@ func marshallHtlcFailure(failure *paymentsdb.HTLCFailInfo) (*lnrpc.Failure, rpcFailure := &lnrpc.Failure{ FailureSourceIndex: failure.FailureSourceIndex, + HoldTimes: failure.HoldTimes, } switch failure.Reason { @@ -1614,6 +1615,7 @@ func marshallError(sendError error) (*lnrpc.Failure, error) { fErr, ok := rtErr.(*htlcswitch.ForwardingError) if ok { response.FailureSourceIndex = uint32(fErr.FailureSourceIdx) + response.HoldTimes = fErr.HoldTimes } return response, nil diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 02f5f9ccff4..8684093cb1c 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -1188,13 +1188,14 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID) pd = &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - ParentIndex: ogHTLC.HtlcIndex, - LogIndex: logUpdate.LogIndex, - EntryType: Fail, - FailReason: wireMsg.Reason[:], + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: Fail, + FailReason: wireMsg.Reason[:], + FailExtraData: wireMsg.ExtraData, removeCommitHeights: lntypes.Dual[uint64]{ Remote: commitHeight, }, @@ -1288,13 +1289,14 @@ func (lc *LightningChannel) localLogUpdateToPayDesc(logUpdate *channeldb.LogUpda ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID) return &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - ParentIndex: ogHTLC.HtlcIndex, - LogIndex: logUpdate.LogIndex, - EntryType: Fail, - FailReason: wireMsg.Reason[:], + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: Fail, + FailReason: wireMsg.Reason[:], + FailExtraData: wireMsg.ExtraData, removeCommitHeights: lntypes.Dual[uint64]{ Remote: commitHeight, }, @@ -1407,13 +1409,14 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd ogHTLC := localUpdateLog.lookupHtlc(wireMsg.ID) return &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - ParentIndex: ogHTLC.HtlcIndex, - LogIndex: logUpdate.LogIndex, - EntryType: Fail, - FailReason: wireMsg.Reason[:], + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: Fail, + FailReason: wireMsg.Reason[:], + FailExtraData: wireMsg.ExtraData, removeCommitHeights: lntypes.Dual[uint64]{ Local: commitHeight, }, @@ -6524,7 +6527,8 @@ func (lc *LightningChannel) ReceiveHTLCSettle(preimage [32]byte, htlcIndex uint6 // NOTE: It is okay for sourceRef, destRef, and closeKey to be nil when unit // testing the wallet. func (lc *LightningChannel) FailHTLC(htlcIndex uint64, reason []byte, - sourceRef *channeldb.AddRef, destRef *channeldb.SettleFailRef, + extraData lnwire.ExtraOpaqueData, sourceRef *channeldb.AddRef, + destRef *channeldb.SettleFailRef, closeKey *models.CircuitKey) error { lc.Lock() @@ -6552,6 +6556,7 @@ func (lc *LightningChannel) FailHTLC(htlcIndex uint64, reason []byte, SourceRef: sourceRef, DestRef: destRef, ClosedCircuitKey: closeKey, + FailExtraData: extraData, } lc.updateLogs.Local.appendUpdate(pd) @@ -6619,7 +6624,7 @@ func (lc *LightningChannel) MalformedFailHTLC(htlcIndex uint64, // commitment update. This method should be called in response to the upstream // party cancelling an outgoing HTLC. func (lc *LightningChannel) ReceiveFailHTLC(htlcIndex uint64, reason []byte, -) error { + extraData lnwire.ExtraOpaqueData) error { lc.Lock() defer lc.Unlock() @@ -6636,13 +6641,14 @@ func (lc *LightningChannel) ReceiveFailHTLC(htlcIndex uint64, reason []byte, } pd := &paymentDescriptor{ - ChanID: lc.ChannelID(), - Amount: htlc.Amount, - RHash: htlc.RHash, - ParentIndex: htlc.HtlcIndex, - LogIndex: lc.updateLogs.Remote.logIndex, - EntryType: Fail, - FailReason: reason, + ChanID: lc.ChannelID(), + Amount: htlc.Amount, + RHash: htlc.RHash, + ParentIndex: htlc.HtlcIndex, + LogIndex: lc.updateLogs.Remote.logIndex, + EntryType: Fail, + FailReason: reason, + FailExtraData: extraData, } lc.updateLogs.Remote.appendUpdate(pd) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 6e175ba7392..c07100488a6 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -466,9 +466,9 @@ func TestChannelZeroAddLocalHeight(t *testing.T) { // Now Bob should fail the htlc back to Alice. // <----fail----- - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) require.NoError(t, err) - err = aliceChannel.ReceiveFailHTLC(0, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(0, []byte("bad"), nil) require.NoError(t, err) // Bob should send a commitment signature to Alice. @@ -2222,9 +2222,11 @@ func TestCancelHTLC(t *testing.T) { // Now, with the HTLC committed on both sides, trigger a cancellation // from Bob to Alice, removing the HTLC. - err = bobChannel.FailHTLC(bobHtlcIndex, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC( + bobHtlcIndex, []byte("failreason"), nil, nil, nil, nil, + ) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(aliceHtlcIndex, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(aliceHtlcIndex, []byte("bad"), nil) require.NoError(t, err, "unable to recv htlc cancel") // Now trigger another state transition, the HTLC should now be removed @@ -5509,9 +5511,9 @@ func TestChanAvailableBandwidth(t *testing.T) { } htlcIndex := uint64((numHtlcs * 2) - 1) - err = bobChannel.FailHTLC(htlcIndex, []byte("f"), nil, nil, nil) + err = bobChannel.FailHTLC(htlcIndex, []byte("f"), nil, nil, nil, nil) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(htlcIndex, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(htlcIndex, []byte("bad"), nil) require.NoError(t, err, "unable to recv htlc cancel") // We must do a state transition before the balance is available @@ -5965,9 +5967,11 @@ func TestLockedInHtlcForwardingSkipAfterRestart(t *testing.T) { // With both nodes restarted, Bob will now attempt to cancel one of // Alice's HTLC's. - err = bobChannel.FailHTLC(htlc.ID, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC( + htlc.ID, []byte("failreason"), nil, nil, nil, nil, + ) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(htlc.ID, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(htlc.ID, []byte("bad"), nil) require.NoError(t, err, "unable to recv htlc cancel") // We'll now initiate another state transition, but this time Bob will @@ -6018,9 +6022,11 @@ func TestLockedInHtlcForwardingSkipAfterRestart(t *testing.T) { // Failing the HTLC here will cause the update to be included in Alice's // remote log, but it should not be committed by this transition. - err = bobChannel.FailHTLC(htlc2.ID, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC( + htlc2.ID, []byte("failreason"), nil, nil, nil, nil, + ) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(htlc2.ID, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(htlc2.ID, []byte("bad"), nil) require.NoError(t, err, "unable to recv htlc cancel") bobRevocation, _, finalHtlcs, err := bobChannel. @@ -6073,9 +6079,11 @@ func TestLockedInHtlcForwardingSkipAfterRestart(t *testing.T) { // Re-add the Fail to both Alice and Bob's channels, as the non-committed // update will not have survived the restart. - err = bobChannel.FailHTLC(htlc2.ID, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC( + htlc2.ID, []byte("failreason"), nil, nil, nil, nil, + ) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(htlc2.ID, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(htlc2.ID, []byte("bad"), nil) require.NoError(t, err, "unable to recv htlc cancel") // Have Alice initiate a state transition, which does not include the @@ -6520,9 +6528,14 @@ func TestDesyncHTLCs(t *testing.T) { } // Now let Bob fail this HTLC. - err = bobChannel.FailHTLC(bobIndex, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC( + bobIndex, []byte("failreason"), nil, nil, nil, nil, + ) require.NoError(t, err, "unable to cancel HTLC") - if err := aliceChannel.ReceiveFailHTLC(aliceIndex, []byte("bad")); err != nil { + err = aliceChannel.ReceiveFailHTLC( + aliceIndex, []byte("bad"), nil, + ) + if err != nil { t.Fatalf("unable to recv htlc cancel: %v", err) } @@ -6612,10 +6625,11 @@ func TestMaxAcceptedHTLCs(t *testing.T) { // Bob will fail the htlc specified by htlcID and then force a state // transition. - err = bobChannel.FailHTLC(htlcID, []byte{}, nil, nil, nil) + err = bobChannel.FailHTLC(htlcID, []byte{}, nil, nil, nil, nil) require.NoError(t, err, "unable to fail htlc") - if err := aliceChannel.ReceiveFailHTLC(htlcID, []byte{}); err != nil { + err = aliceChannel.ReceiveFailHTLC(htlcID, []byte{}, nil) + if err != nil { t.Fatalf("unable to receive fail htlc: %v", err) } @@ -6718,10 +6732,11 @@ func TestMaxAsynchronousHtlcs(t *testing.T) { addAndReceiveHTLC(t, aliceChannel, bobChannel, htlc, nil) // Fail back an HTLC and sign a commitment as in steps 1 & 2. - err = bobChannel.FailHTLC(htlcID, []byte{}, nil, nil, nil) + err = bobChannel.FailHTLC(htlcID, []byte{}, nil, nil, nil, nil) require.NoError(t, err, "unable to fail htlc") - if err := aliceChannel.ReceiveFailHTLC(htlcID, []byte{}); err != nil { + err = aliceChannel.ReceiveFailHTLC(htlcID, []byte{}, nil) + if err != nil { t.Fatalf("unable to receive fail htlc: %v", err) } @@ -7546,10 +7561,10 @@ func TestChannelRestoreUpdateLogsFailedHTLC(t *testing.T) { restoreAndAssert(t, aliceChannel, 1, 0, 0, 0) // Now we make Bob fail this HTLC. - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(0, []byte("failreason")) + err = aliceChannel.ReceiveFailHTLC(0, []byte("failreason"), nil) require.NoError(t, err, "unable to recv htlc cancel") // This Fail update should have been added to Alice's remote update log. @@ -7632,19 +7647,22 @@ func TestDuplicateFailRejection(t *testing.T) { // With the HTLC locked in, we'll now have Bob fail the HTLC back to // Alice. - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) require.NoError(t, err, "unable to cancel HTLC") - if err := aliceChannel.ReceiveFailHTLC(0, []byte("bad")); err != nil { + err = aliceChannel.ReceiveFailHTLC(0, []byte("bad"), nil) + if err != nil { t.Fatalf("unable to recv htlc cancel: %v", err) } // If we attempt to fail it AGAIN, then both sides should reject this // second failure attempt. - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) if err == nil { t.Fatalf("duplicate HTLC failure attempt should have failed") } - if err := aliceChannel.ReceiveFailHTLC(0, []byte("bad")); err == nil { + + err = aliceChannel.ReceiveFailHTLC(0, []byte("bad"), nil) + if err == nil { t.Fatalf("duplicate HTLC failure attempt should have failed") } @@ -7661,14 +7679,15 @@ func TestDuplicateFailRejection(t *testing.T) { require.NoError(t, err, "unable to restart channel") // If we try to fail the same HTLC again, then we should get an error. - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) if err == nil { t.Fatalf("duplicate HTLC failure attempt should have failed") } // Alice on the other hand should accept the failure again, as she // dropped all items in the logs which weren't committed. - if err := aliceChannel.ReceiveFailHTLC(0, []byte("bad")); err != nil { + err = aliceChannel.ReceiveFailHTLC(0, []byte("bad"), nil) + if err != nil { t.Fatalf("unable to recv htlc cancel: %v", err) } } @@ -7929,9 +7948,9 @@ func TestChannelRestoreCommitHeight(t *testing.T) { bobChannel = restoreAndAssertCommitHeights(t, bobChannel, true, 1, 2, 2) // Bob now fails back the htlc that was just locked in. - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) require.NoError(t, err, "unable to cancel HTLC") - err = aliceChannel.ReceiveFailHTLC(0, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(0, []byte("bad"), nil) require.NoError(t, err, "unable to recv htlc cancel") // Now Bob signs for the fail update. @@ -9252,9 +9271,9 @@ func TestChannelUnsignedAckedFailure(t *testing.T) { // Now Bob should fail the htlc back to Alice. // <----fail----- - err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil, nil) require.NoError(t, err) - err = aliceChannel.ReceiveFailHTLC(0, []byte("bad")) + err = aliceChannel.ReceiveFailHTLC(0, []byte("bad"), nil) require.NoError(t, err) // Bob should send a commitment signature to Alice. @@ -9356,9 +9375,11 @@ func TestChannelLocalUnsignedUpdatesFailure(t *testing.T) { // Now Alice should fail the htlc back to Bob. // -----fail---> - err = aliceChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + err = aliceChannel.FailHTLC( + 0, []byte("failreason"), nil, nil, nil, nil, + ) require.NoError(t, err) - err = bobChannel.ReceiveFailHTLC(0, []byte("bad")) + err = bobChannel.ReceiveFailHTLC(0, []byte("bad"), nil) require.NoError(t, err) // Alice should send a commitment signature to Bob. @@ -10801,10 +10822,10 @@ func TestAsynchronousSendingWithFeeBuffer(t *testing.T) { // <----rev------- |--------------- // <----sig------- |--------------- // --------------- |-----rev------> - err = aliceChannel.FailHTLC(0, []byte{}, nil, nil, nil) + err = aliceChannel.FailHTLC(0, []byte{}, nil, nil, nil, nil) require.NoError(t, err) - err = bobChannel.ReceiveFailHTLC(0, []byte{}) + err = bobChannel.ReceiveFailHTLC(0, []byte{}, nil) require.NoError(t, err) err = ForceStateTransition(aliceChannel, bobChannel) diff --git a/lnwallet/payment_descriptor.go b/lnwallet/payment_descriptor.go index 944749bde9f..ac8f6a27f03 100644 --- a/lnwallet/payment_descriptor.go +++ b/lnwallet/payment_descriptor.go @@ -246,6 +246,10 @@ type paymentDescriptor struct { // CustomRecords also stores the set of optional custom records that // may have been attached to a sent HTLC. CustomRecords lnwire.CustomRecords + + // FailExtraData stores any extra opaque data that may have been present + // when receiving an UpdateFailHTLC message. + FailExtraData lnwire.ExtraOpaqueData } // toLogUpdate recovers the underlying LogUpdate from the paymentDescriptor. @@ -274,9 +278,10 @@ func (pd *paymentDescriptor) toLogUpdate() channeldb.LogUpdate { } case Fail: msg = &lnwire.UpdateFailHTLC{ - ChanID: pd.ChanID, - ID: pd.ParentIndex, - Reason: pd.FailReason, + ChanID: pd.ChanID, + ID: pd.ParentIndex, + Reason: pd.FailReason, + ExtraData: pd.FailExtraData, } case MalformedFail: msg = &lnwire.UpdateFailMalformedHTLC{ diff --git a/lnwire/attr_data.go b/lnwire/attr_data.go new file mode 100644 index 00000000000..65a6559bfc4 --- /dev/null +++ b/lnwire/attr_data.go @@ -0,0 +1,27 @@ +package lnwire + +import "github.com/lightningnetwork/lnd/tlv" + +// AttrDataTlvType is the TlvType that hosts the attribution data in the +// update_fail_htlc wire message. +var AttrDataTlvType tlv.TlvType101 + +// AttrDataToExtraData converts the provided attribution data to the extra +// opaque data to be included in the wire message. +func AttrDataToExtraData(attrData []byte) (ExtraOpaqueData, error) { + attrRecs := make(tlv.TypeMap) + attrRecs[AttrDataTlvType.TypeVal()] = attrData + + return NewExtraOpaqueData(attrRecs) +} + +// ExtraDataToAttrData takes the extra opaque data of the wire message and tries +// to extract the attribution data. +func ExtraDataToAttrData(extraData ExtraOpaqueData) ([]byte, error) { + extraRecords, err := extraData.ExtractRecords() + if err != nil { + return nil, err + } + + return extraRecords[AttrDataTlvType.TypeVal()], nil +} diff --git a/lnwire/attr_data_test.go b/lnwire/attr_data_test.go new file mode 100644 index 00000000000..b7bf53a92ca --- /dev/null +++ b/lnwire/attr_data_test.go @@ -0,0 +1,104 @@ +package lnwire + +import ( + "testing" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestAttrDataRoundTrip tests that attribution data survives a round-trip +// through AttrDataToExtraData and ExtraDataToAttrData. +func TestAttrDataRoundTrip(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + attrData []byte + }{ + { + name: "nil attribution data", + attrData: nil, + }, + { + name: "empty attribution data", + attrData: []byte{}, + }, + { + name: "small attribution data", + attrData: []byte{0x01, 0x02, 0x03}, + }, + { + name: "realistic size attribution data", + attrData: make([]byte, 1200), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + extraData, err := AttrDataToExtraData(tc.attrData) + require.NoError(t, err) + + recovered, err := ExtraDataToAttrData(extraData) + require.NoError(t, err) + + // Both nil and empty should round-trip to equivalent + // "no data" values. + if len(tc.attrData) == 0 { + require.Empty(t, recovered) + return + } + + require.Equal(t, tc.attrData, recovered) + }) + } +} + +// TestExtraDataToAttrDataNoRecord tests that ExtraDataToAttrData returns nil +// when the extra data does not contain an attribution record. +func TestExtraDataToAttrDataNoRecord(t *testing.T) { + t.Parallel() + + // Build extra data with a different TLV type (not 101). + otherType := tlv.Type(200) + records := make(tlv.TypeMap) + records[otherType] = []byte{0xaa, 0xbb} + extraData, err := NewExtraOpaqueData(records) + require.NoError(t, err) + + attrData, err := ExtraDataToAttrData(extraData) + require.NoError(t, err) + require.Nil(t, attrData) +} + +// TestAttrDataPreservesOtherRecords verifies that encoding attribution data +// into ExtraOpaqueData produces a valid TLV stream with the correct type. +func TestAttrDataTlvType(t *testing.T) { + t.Parallel() + + payload := []byte{0xde, 0xad, 0xbe, 0xef} + extraData, err := AttrDataToExtraData(payload) + require.NoError(t, err) + + // Extract all records and verify the attribution type is present. + records, err := extraData.ExtractRecords() + require.NoError(t, err) + + value, ok := records[AttrDataTlvType.TypeVal()] + require.True(t, ok, "expected attribution TLV type %d in records", + AttrDataTlvType.TypeVal()) + require.Equal(t, payload, value) +} + +// TestExtraDataToAttrDataEmpty tests that an empty ExtraOpaqueData returns nil +// attribution data without error. +func TestExtraDataToAttrDataEmpty(t *testing.T) { + t.Parallel() + + var empty ExtraOpaqueData + attrData, err := ExtraDataToAttrData(empty) + require.NoError(t, err) + require.Nil(t, attrData) +} diff --git a/payments/db/kv_store.go b/payments/db/kv_store.go index b3fffe1e182..e7b03edfcc7 100644 --- a/payments/db/kv_store.go +++ b/payments/db/kv_store.go @@ -2071,7 +2071,25 @@ func serializeHTLCFailInfo(w io.Writer, f *HTLCFailInfo) error { return err } - return WriteElements(w, byte(f.Reason), f.FailureSourceIndex) + err := WriteElements(w, byte(f.Reason), f.FailureSourceIndex) + if err != nil { + return err + } + + // Write hold times count followed by each value. This is appended + // after the original fields for backward compatibility — old readers + // will simply stop reading at the end of FailureSourceIndex. + numHoldTimes := uint16(len(f.HoldTimes)) + if err := WriteElements(w, numHoldTimes); err != nil { + return err + } + for _, ht := range f.HoldTimes { + if err := WriteElements(w, ht); err != nil { + return err + } + } + + return nil } // deserializeHTLCFailInfo deserializes the details of a failed htlc including @@ -2117,5 +2135,28 @@ func deserializeHTLCFailInfo(r io.Reader) (*HTLCFailInfo, error) { } f.Reason = HTLCFailReason(reason) + // Read hold times if present. Old data won't have this field, so we + // treat EOF as "no hold times". + var numHoldTimes uint16 + if err := ReadElements(r, &numHoldTimes); err != nil { + // If there's no more data, this is old format — return + // without hold times. + if errors.Is(err, io.EOF) || + errors.Is(err, io.ErrUnexpectedEOF) { + + return f, nil + } + + return nil, err + } + if numHoldTimes > 0 { + f.HoldTimes = make([]uint32, numHoldTimes) + for i := range f.HoldTimes { + if err := ReadElements(r, &f.HoldTimes[i]); err != nil { + return nil, err + } + } + } + return f, nil } diff --git a/payments/db/payment.go b/payments/db/payment.go index ddceedfb0f0..129de8c0477 100644 --- a/payments/db/payment.go +++ b/payments/db/payment.go @@ -298,6 +298,12 @@ type HTLCFailInfo struct { // field will be populated when the failure reason is either // HTLCFailMessage or HTLCFailUnknown. FailureSourceIndex uint32 + + // HoldTimes is an array of hold times (in 100ms units) as reported by + // nodes along the route via attributable errors. The first element + // corresponds to the first hop after the sender. This field is only + // populated when the error includes attribution data. + HoldTimes []uint32 } // MPPaymentState wraps a series of info needed for a given payment, which is diff --git a/peer/brontide.go b/peer/brontide.go index d624ac9516c..e97ffae1563 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -19,6 +19,7 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog/v2" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/actor" "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/brontide" @@ -1471,9 +1472,30 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, //nolint:ll linkCfg := htlcswitch.ChannelLinkConfig{ - Peer: p, - DecodeHopIterators: p.cfg.SphinxPayment.DecodeHopIterators, - ExtractErrorEncrypter: p.cfg.SphinxPayment.ExtractErrorEncrypter, + Peer: p, + DecodeHopIterators: p.cfg.SphinxPayment.DecodeHopIterators, + ExtractSharedSecret: p.cfg.SphinxPayment.ExtractSharedSecret, + CreateErrorEncrypter: func(ephemeralKey *btcec.PublicKey, + sharedSecret sphinx.Hash256, isIntroduction, + hasBlindingPoint bool) hop.ErrorEncrypter { + + switch { + case isIntroduction: + return hop.NewIntroductionErrorEncrypter( + ephemeralKey, sharedSecret, + ) + + case hasBlindingPoint: + return hop.NewRelayingErrorEncrypter( + ephemeralKey, sharedSecret, + ) + + default: + return hop.NewSphinxErrorEncrypter( + ephemeralKey, sharedSecret, + ) + } + }, FetchLastChannelUpdate: p.cfg.FetchLastChanUpdate, HodlMask: p.cfg.Hodl.Mask(), Registry: p.cfg.Invoices, diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 697770ff506..6927fd4b375 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -484,7 +484,7 @@ func testKVStoreSubscribeFail(t *testing.T, registerAttempt bool) { if err != nil { t.Fatalf("unable to fail htlc: %v", err) } - if *htlcAttempt.Failure != failInfo { + if !reflect.DeepEqual(*htlcAttempt.Failure, failInfo) { t.Fatalf("unexpected fail info returned") } } diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 488df5b796e..09ab42c48f7 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -7,7 +7,6 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" - sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" @@ -558,9 +557,7 @@ func (p *paymentLifecycle) collectResult( // Using the created circuit, initialize the error decrypter, so we can // parse+decode any failures incurred by this payment within the // switch. - errorDecryptor := &htlcswitch.SphinxErrorDecrypter{ - OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit), - } + errorDecryptor := htlcswitch.NewSphinxErrorDecrypter(circuit) // Now ask the switch to return the result of the payment when // available. @@ -1099,6 +1096,7 @@ func marshallError(sendError error, time time.Time) *paymentsdb.HTLCFailInfo { ok = errors.As(rtErr, &fErr) if ok { response.FailureSourceIndex = uint32(fErr.FailureSourceIdx) + response.HoldTimes = fErr.HoldTimes } return response diff --git a/routing/router_test.go b/routing/router_test.go index e9c9ed94591..8b68267869f 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -319,7 +319,7 @@ func TestSendPaymentRouteFailureFallback(t *testing.T) { // TODO(roasbeef): temp node failure // should be? &lnwire.FailTemporaryChannelFailure{}, - 1, + 1, nil, ) } @@ -390,7 +390,7 @@ func TestSendPaymentRouteInfiniteLoopWithBadHopHint(t *testing.T) { // the bad channel is the first hop. badShortChanID := lnwire.NewShortChanIDFromInt(badChannelID) newFwdError := htlcswitch.NewForwardingError( - &lnwire.FailUnknownNextPeer{}, 0, + &lnwire.FailUnknownNextPeer{}, 0, nil, ) payer, ok := ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld) @@ -511,7 +511,7 @@ func TestChannelUpdateValidation(t *testing.T) { &lnwire.FailFeeInsufficient{ Update: errChanUpdate, }, - 1, + 1, nil, ) }) @@ -633,7 +633,7 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { // node/channel. &lnwire.FailFeeInsufficient{ Update: errChanUpdate, - }, 1, + }, 1, nil, ) } @@ -744,7 +744,7 @@ func TestSendPaymentErrorFeeInsufficientPrivateEdge(t *testing.T) { // node/channel. &lnwire.FailFeeInsufficient{ Update: errChanUpdate, - }, 1, + }, 1, nil, ) }, ) @@ -872,7 +872,7 @@ func TestSendPaymentPrivateEdgeUpdateFeeExceedsLimit(t *testing.T) { // node/channel. &lnwire.FailFeeInsufficient{ Update: errChanUpdate, - }, 1, + }, 1, nil, ) }, ) @@ -973,7 +973,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailExpiryTooSoon{ Update: errChanUpdate, - }, 1, + }, 1, nil, ) } @@ -1023,7 +1023,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailIncorrectCltvExpiry{ Update: errChanUpdate, - }, 1, + }, 1, nil, ) } @@ -1081,7 +1081,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // sophon not having enough capacity. return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailTemporaryChannelFailure{}, - 1, + 1, nil, ) } @@ -1090,7 +1090,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // which should prune out the rest of the routes. if firstHop == roasbeefPhanNuwen { return [32]byte{}, htlcswitch.NewForwardingError( - &lnwire.FailUnknownNextPeer{}, 1, + &lnwire.FailUnknownNextPeer{}, 1, nil, ) } @@ -1139,7 +1139,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { if firstHop == roasbeefSongoku { failure := htlcswitch.NewForwardingError( - &lnwire.FailUnknownNextPeer{}, 1, + &lnwire.FailUnknownNextPeer{}, 1, nil, ) return [32]byte{}, failure } @@ -1184,7 +1184,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // roasbeef not having enough capacity. return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailTemporaryChannelFailure{}, - 1, + 1, nil, ) } return preImage, nil @@ -1433,7 +1433,7 @@ func TestSendToRouteStructuredError(t *testing.T) { ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcherOld).setPaymentResult( func(firstHop lnwire.ShortChannelID) ([32]byte, error) { return [32]byte{}, htlcswitch.NewForwardingError( - errorType, failIndex, + errorType, failIndex, nil, ) }, ) @@ -2361,7 +2361,7 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { // Create the error to be returned. tempErr := htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, 1, + &lnwire.FailTemporaryChannelFailure{}, 1, nil, ) // Register mockers with the expected method calls. @@ -2441,7 +2441,7 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { // Create the error to be returned. permErr := htlcswitch.NewForwardingError( - &lnwire.FailIncorrectDetails{}, 1, + &lnwire.FailIncorrectDetails{}, 1, nil, ) // Register mockers with the expected method calls. @@ -2527,7 +2527,7 @@ func TestSendToRouteTempFailure(t *testing.T) { // Create the error to be returned. tempErr := htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, 1, + &lnwire.FailTemporaryChannelFailure{}, 1, nil, ) // Register mockers with the expected method calls. diff --git a/server.go b/server.go index 0359dc7e3a8..b9bd9d36488 100644 --- a/server.go +++ b/server.go @@ -817,7 +817,7 @@ func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, }, FwdingLog: dbs.ChanStateDB.ForwardingLog(), SwitchPackager: channeldb.NewSwitchPackager(), - ExtractErrorEncrypter: s.sphinxPayment.ExtractErrorEncrypter, + ExtractSharedSecret: s.sphinxPayment.ExtractSharedSecret, FetchLastChannelUpdate: s.fetchLastChanUpdate(), Notifier: s.cc.ChainNotifier, HtlcNotifier: s.htlcNotifier,