diff --git a/go.mod b/go.mod index 88741ad8d75..739e2746dce 100644 --- a/go.mod +++ b/go.mod @@ -62,6 +62,7 @@ require ( google.golang.org/protobuf v1.33.0 gopkg.in/macaroon-bakery.v2 v2.0.1 gopkg.in/macaroon.v2 v2.0.0 + pgregory.net/rapid v1.1.0 ) require ( diff --git a/go.sum b/go.sum index 09e515c0689..4931c5e435b 100644 --- a/go.sum +++ b/go.sum @@ -1070,6 +1070,8 @@ modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw= +pgregory.net/rapid v1.1.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 468eba13198..a59cedf06f3 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2164,11 +2164,21 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { // We just received a new updates to our local commitment // chain, validate this new commitment, closing the link if // invalid. + auxSigBlob, err := msg.CustomRecords.Serialize() + if err != nil { + l.fail( + LinkFailureError{code: ErrInternalError}, + "unable to serialize custom records: %v", + err, + ) + + return + } err = l.channel.ReceiveNewCommitment(&lnwallet.CommitSigs{ CommitSig: msg.CommitSig, HtlcSigs: msg.HtlcSigs, PartialSig: msg.PartialSig, - AuxSigBlob: msg.ExtraData, + AuxSigBlob: auxSigBlob, }) if err != nil { // If we were unable to reconstruct their proposed @@ -2577,12 +2587,17 @@ func (l *channelLink) updateCommitTx() error { default: } + auxBlobRecords, err := lnwire.ParseCustomRecords(newCommit.AuxSigBlob) + if err != nil { + return fmt.Errorf("error parsing aux sigs: %w", err) + } + commitSig := &lnwire.CommitSig{ - ChanID: l.ChanID(), - CommitSig: newCommit.CommitSig, - HtlcSigs: newCommit.HtlcSigs, - PartialSig: newCommit.PartialSig, - ExtraData: newCommit.AuxSigBlob, + ChanID: l.ChanID(), + CommitSig: newCommit.CommitSig, + HtlcSigs: newCommit.HtlcSigs, + PartialSig: newCommit.PartialSig, + CustomRecords: auxBlobRecords, } l.cfg.Peer.SendMessage(false, commitSig) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index fc279c264a9..b420b6134c0 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -4028,6 +4028,10 @@ func (lc *LightningChannel) createCommitDiff(newCommit *commitment, if err != nil { return nil, fmt.Errorf("error packing aux sigs: %w", err) } + auxBlobRecords, err := lnwire.ParseCustomRecords(auxSigBlob) + if err != nil { + return nil, fmt.Errorf("error parsing aux sigs: %w", err) + } return &channeldb.CommitDiff{ Commitment: *diskCommit, @@ -4035,9 +4039,9 @@ func (lc *LightningChannel) createCommitDiff(newCommit *commitment, ChanID: lnwire.NewChanIDFromOutPoint( lc.channelState.FundingOutpoint, ), - CommitSig: commitSig, - HtlcSigs: htlcSigs, - ExtraData: auxSigBlob, + CommitSig: commitSig, + HtlcSigs: htlcSigs, + CustomRecords: auxBlobRecords, }, LogUpdates: logUpdates, OpenedCircuitKeys: openCircuitKeys, @@ -4737,17 +4741,44 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { // latest commitment update. lc.remoteCommitChain.addCommitment(newCommitView) + auxSigBlob, err := commitDiff.CommitSig.CustomRecords.Serialize() + if err != nil { + return nil, fmt.Errorf("unable to serialize aux sig "+ + "blob: %v", err) + } + return &NewCommitState{ CommitSigs: &CommitSigs{ CommitSig: sig, HtlcSigs: htlcSigs, PartialSig: lnwire.MaybePartialSigWithNonce(partialSig), - AuxSigBlob: commitDiff.CommitSig.ExtraData, + AuxSigBlob: auxSigBlob, }, PendingHTLCs: commitDiff.Commitment.Htlcs, }, nil } +// resignMusigCommit is used to resign a commitment transaction for taproot +// channels when we need to retransmit a signature after a channel reestablish +// message. Taproot channels use musig2, which means we must use fresh nonces +// each time. After we receive the channel reestablish message, we learn the +// nonce we need to use for the remote party. As a result, we need to generate +// the partial signature again with the new nonce. +func (lc *LightningChannel) resignMusigCommit(commitTx *wire.MsgTx, +) (lnwire.OptPartialSigWithNonceTLV, error) { + + remoteSession := lc.musigSessions.RemoteSession + musig, err := remoteSession.SignCommit(commitTx) + if err != nil { + var none lnwire.OptPartialSigWithNonceTLV + return none, err + } + + partialSig := lnwire.MaybePartialSigWithNonce(musig.ToWireSig()) + + return partialSig, nil +} + // ProcessChanSyncMsg processes a ChannelReestablish message sent by the remote // connection upon re establishment of our connection with them. This method // will return a single message if we are currently out of sync, otherwise a @@ -4939,13 +4970,23 @@ func (lc *LightningChannel) ProcessChanSyncMsg( // If we signed this state, then we'll accumulate // another update to send over. case err == nil: + blobRecords, err := lnwire.ParseCustomRecords( + newCommit.AuxSigBlob, + ) + if err != nil { + sErr := fmt.Errorf("error parsing "+ + "aux sigs: %w", err) + return nil, nil, nil, sErr + } + commitSig := &lnwire.CommitSig{ ChanID: lnwire.NewChanIDFromOutPoint( lc.channelState.FundingOutpoint, ), - CommitSig: newCommit.CommitSig, - HtlcSigs: newCommit.HtlcSigs, - PartialSig: newCommit.PartialSig, + CommitSig: newCommit.CommitSig, + HtlcSigs: newCommit.HtlcSigs, + PartialSig: newCommit.PartialSig, + CustomRecords: blobRecords, } updates = append(updates, commitSig) @@ -5025,12 +5066,23 @@ func (lc *LightningChannel) ProcessChanSyncMsg( commitUpdates = append(commitUpdates, logUpdate.UpdateMsg) } + // If this is a taproot channel, then we need to regenerate the + // musig2 signature for the remote party, using their fresh + // nonce. + if lc.channelState.ChanType.IsTaproot() { + partialSig, err := lc.resignMusigCommit( + commitDiff.Commitment.CommitTx, + ) + if err != nil { + return nil, nil, nil, err + } + + commitDiff.CommitSig.PartialSig = partialSig + } + // With the batch of updates accumulated, we'll now re-send the // original CommitSig message required to re-sync their remote // commitment chain with our local version of their chain. - // - // TODO(roasbeef): need to re-sign commitment states w/ - // fresh nonce commitUpdates = append(commitUpdates, commitDiff.CommitSig) // NOTE: If a revocation is not owed, then updates is empty. @@ -5500,9 +5552,9 @@ func genHtlcSigValidationJobs(chanState *channeldb.OpenChannel, // store in the custom records map so we can write to // disk later. sigType := htlcCustomSigType.TypeVal() - htlc.CustomRecords[uint64(sigType)] = auxSig.UnwrapOr( - nil, - ) + auxSig.WhenSome(func(sigB tlv.Blob) { + htlc.CustomRecords[uint64(sigType)] = sigB + }) auxVerifyJobs = append(auxVerifyJobs, auxVerifyJob) } diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 9b93ef6ac99..4d972ece49a 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -28,6 +28,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -3025,19 +3026,11 @@ func restartChannel(channelOld *LightningChannel) (*LightningChannel, error) { return channelNew, nil } -// TestChanSyncOweCommitment tests that if Bob restarts (and then Alice) before -// he receives Alice's CommitSig message, then Alice concludes that she needs -// to re-send the CommitDiff. After the diff has been sent, both nodes should -// resynchronize and be able to complete the dangling commit. -func TestChanSyncOweCommitment(t *testing.T) { - t.Parallel() - +func testChanSyncOweCommitment(t *testing.T, chanType channeldb.ChannelType) { // Create a test channel which will be used for the duration of this // unittest. The channel will be funded evenly with Alice having 5 BTC, // and Bob having 5 BTC. - aliceChannel, bobChannel, err := CreateTestChannels( - t, channeldb.SingleFunderTweaklessBit, - ) + aliceChannel, bobChannel, err := CreateTestChannels(t, chanType) require.NoError(t, err, "unable to create test channels") var fakeOnionBlob [lnwire.OnionPacketSize]byte @@ -3112,6 +3105,15 @@ func TestChanSyncOweCommitment(t *testing.T) { aliceNewCommit, err := aliceChannel.SignNextCommitment() require.NoError(t, err, "unable to sign commitment") + // If this is a taproot channel, then we'll generate fresh verification + // nonce for both sides. + if chanType.IsTaproot() { + _, err = aliceChannel.GenMusigNonces() + require.NoError(t, err) + _, err = bobChannel.GenMusigNonces() + require.NoError(t, err) + } + // Bob doesn't get this message so upon reconnection, they need to // synchronize. Alice should conclude that she owes Bob a commitment, // while Bob should think he's properly synchronized. @@ -3123,7 +3125,7 @@ func TestChanSyncOweCommitment(t *testing.T) { // This is a helper function that asserts Alice concludes that she // needs to retransmit the exact commitment that we failed to send // above. - assertAliceCommitRetransmit := func() { + assertAliceCommitRetransmit := func() *lnwire.CommitSig { aliceMsgsToSend, _, _, err := aliceChannel.ProcessChanSyncMsg( bobSyncMsg, ) @@ -3188,12 +3190,25 @@ func TestChanSyncOweCommitment(t *testing.T) { len(commitSigMsg.HtlcSigs)) } for i, htlcSig := range commitSigMsg.HtlcSigs { - if htlcSig != aliceNewCommit.HtlcSigs[i] { + if !bytes.Equal(htlcSig.RawBytes(), + aliceNewCommit.HtlcSigs[i].RawBytes()) { + t.Fatalf("htlc sig msgs don't match: "+ - "expected %x got %x", - aliceNewCommit.HtlcSigs[i], htlcSig) + "expected %v got %v", + spew.Sdump(aliceNewCommit.HtlcSigs[i]), + spew.Sdump(htlcSig)) } } + + // If this is a taproot channel, then partial sig information + // should be present in the commit sig sent over. This + // signature will be re-regenerated, so we can't compare it + // with the old one. + if chanType.IsTaproot() { + require.True(t, commitSigMsg.PartialSig.IsSome()) + } + + return commitSigMsg } // Alice should detect that she needs to re-send 5 messages: the 3 @@ -3214,14 +3229,19 @@ func TestChanSyncOweCommitment(t *testing.T) { // send the exact same set of messages. aliceChannel, err = restartChannel(aliceChannel) require.NoError(t, err, "unable to restart alice") - assertAliceCommitRetransmit() - // TODO(roasbeef): restart bob as well??? + // To properly simulate a restart, we'll use the *new* signature that + // would send in an actual p2p setting. + aliceReCommitSig := assertAliceCommitRetransmit() // At this point, we should be able to resume the prior state update // without any issues, resulting in Alice settling the 3 htlc's, and // adding one of her own. - err = bobChannel.ReceiveNewCommitment(aliceNewCommit.CommitSigs) + err = bobChannel.ReceiveNewCommitment(&CommitSigs{ + CommitSig: aliceReCommitSig.CommitSig, + HtlcSigs: aliceReCommitSig.HtlcSigs, + PartialSig: aliceReCommitSig.PartialSig, + }) require.NoError(t, err, "bob unable to process alice's commitment") bobRevocation, _, _, err := bobChannel.RevokeCurrentCommitment() require.NoError(t, err, "unable to revoke bob commitment") @@ -3308,16 +3328,147 @@ func TestChanSyncOweCommitment(t *testing.T) { } } -// TestChanSyncOweCommitmentPendingRemote asserts that local updates are applied -// to the remote commit across restarts. -func TestChanSyncOweCommitmentPendingRemote(t *testing.T) { +// TestChanSyncOweCommitment tests that if Bob restarts (and then Alice) before +// he receives Alice's CommitSig message, then Alice concludes that she needs +// to re-send the CommitDiff. After the diff has been sent, both nodes should +// resynchronize and be able to complete the dangling commit. +func TestChanSyncOweCommitment(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + chanType channeldb.ChannelType + }{ + { + name: "tweakless", + chanType: channeldb.SingleFunderTweaklessBit, + }, + { + name: "anchors", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit, + }, + { + name: "taproot", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit, + }, + { + name: "taproot with tapscript root", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit | + channeldb.TapscriptRootBit, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testChanSyncOweCommitment(t, tc.chanType) + }) + } +} + +type testSigBlob struct { + BlobInt tlv.RecordT[tlv.TlvType65634, uint16] +} + +// TestChanSyncOweCommitmentAuxSigner tests that when one party owes a +// signature after a channel reest, if an aux signer is present, then the +// signature message sent includes the additional aux sigs as extra data. +func TestChanSyncOweCommitmentAuxSigner(t *testing.T) { t.Parallel() // Create a test channel which will be used for the duration of this - // unittest. - aliceChannel, bobChannel, err := CreateTestChannels( - t, channeldb.SingleFunderTweaklessBit, + // unittest. The channel will be funded evenly with Alice having 5 BTC, + // and Bob having 5 BTC. + chanType := channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | channeldb.SimpleTaprootFeatureBit | + channeldb.TapscriptRootBit + + aliceChannel, bobChannel, err := CreateTestChannels(t, chanType) + require.NoError(t, err, "unable to create test channels") + + // We'll now manually attach an aux signer to Alice's channel. + auxSigner := &auxSignerMock{} + aliceChannel.auxSigner = fn.Some[AuxSigner](auxSigner) + + var fakeOnionBlob [lnwire.OnionPacketSize]byte + copy( + fakeOnionBlob[:], + bytes.Repeat([]byte{0x05}, lnwire.OnionPacketSize), ) + + // To kick things off, we'll have Alice send a single HTLC to Bob. + htlcAmt := lnwire.NewMSatFromSatoshis(20000) + var bobPreimage [32]byte + copy(bobPreimage[:], bytes.Repeat([]byte{0}, 32)) + rHash := sha256.Sum256(bobPreimage[:]) + h := &lnwire.UpdateAddHTLC{ + PaymentHash: rHash, + Amount: htlcAmt, + Expiry: uint32(10), + OnionBlob: fakeOnionBlob, + } + + _, err = aliceChannel.AddHTLC(h, nil) + require.NoError(t, err, "unable to recv bob's htlc: %v", err) + + // We'll set up the mock to expect calls to PackSigs and also + // SubmitSubmitSecondLevelSigBatch. + var sigBlobBuf bytes.Buffer + sigBlob := testSigBlob{ + BlobInt: tlv.NewPrimitiveRecord[tlv.TlvType65634, uint16](5), + } + tlvStream, err := tlv.NewStream(sigBlob.BlobInt.Record()) + require.NoError(t, err, "unable to create tlv stream") + require.NoError(t, tlvStream.Encode(&sigBlobBuf)) + + auxSigner.On( + "SubmitSecondLevelSigBatch", mock.Anything, mock.Anything, + mock.Anything, + ).Return(nil).Twice() + auxSigner.On( + "PackSigs", mock.Anything, + ).Return( + fn.Some(sigBlobBuf.Bytes()), nil, + ) + + _, err = aliceChannel.SignNextCommitment() + require.NoError(t, err, "unable to sign commitment") + + _, err = aliceChannel.GenMusigNonces() + require.NoError(t, err, "unable to generate musig nonces") + + // Next we'll simulate a restart, by having Bob send over a chan sync + // message to Alice. + bobSyncMsg, err := bobChannel.channelState.ChanSyncMsg() + require.NoError(t, err, "unable to produce chan sync msg") + + aliceMsgsToSend, _, _, err := aliceChannel.ProcessChanSyncMsg( + bobSyncMsg, + ) + require.NoError(t, err) + require.Len(t, aliceMsgsToSend, 2) + + // The first message should be an update add HTLC. + require.IsType(t, &lnwire.UpdateAddHTLC{}, aliceMsgsToSend[0]) + + // The second should be a commit sig message. + sigMsg, ok := aliceMsgsToSend[1].(*lnwire.CommitSig) + require.True(t, ok) + require.True(t, sigMsg.PartialSig.IsSome()) + + // The signature should have the CustomRecords field set. + require.NotEmpty(t, sigMsg.CustomRecords) +} + +func testChanSyncOweCommitmentPendingRemote(t *testing.T, + chanType channeldb.ChannelType) { + + // Create a test channel which will be used for the duration of this + // unittest. + aliceChannel, bobChannel, err := CreateTestChannels(t, chanType) require.NoError(t, err, "unable to create test channels") var fakeOnionBlob [lnwire.OnionPacketSize]byte @@ -3400,6 +3551,12 @@ func TestChanSyncOweCommitmentPendingRemote(t *testing.T) { bobChannel, err = restartChannel(bobChannel) require.NoError(t, err, "unable to restart bob") + // If this is a taproot channel, then since Bob just restarted, we need + // to exchange nonces once again. + if chanType.IsTaproot() { + require.NoError(t, initMusigNonce(aliceChannel, bobChannel)) + } + // Bob signs the commitment he owes. bobNewCommit, err := bobChannel.SignNextCommitment() require.NoError(t, err, "unable to sign commitment") @@ -3425,6 +3582,45 @@ func TestChanSyncOweCommitmentPendingRemote(t *testing.T) { } } +// TestChanSyncOweCommitmentPendingRemote asserts that local updates are applied +// to the remote commit across restarts. +func TestChanSyncOweCommitmentPendingRemote(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + chanType channeldb.ChannelType + }{ + { + name: "tweakless", + chanType: channeldb.SingleFunderTweaklessBit, + }, + { + name: "anchors", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit, + }, + { + name: "taproot", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit, + }, + { + name: "taproot with tapscript root", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit | + channeldb.TapscriptRootBit, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testChanSyncOweCommitmentPendingRemote(t, tc.chanType) + }) + } +} + // testChanSyncOweRevocation is the internal version of // TestChanSyncOweRevocation that is parameterized based on the type of channel // being used in the test. @@ -3574,8 +3770,6 @@ func testChanSyncOweRevocation(t *testing.T, chanType channeldb.ChannelType) { assertAliceOwesRevoke() - // TODO(roasbeef): restart bob too??? - // We'll continue by then allowing bob to process Alice's revocation // message. _, _, _, _, err = bobChannel.ReceiveRevocation(aliceRevocation) @@ -3624,6 +3818,23 @@ func TestChanSyncOweRevocation(t *testing.T) { testChanSyncOweRevocation(t, taprootBits) }) + t.Run("taproot", func(t *testing.T) { + taprootBits := channeldb.SimpleTaprootFeatureBit | + channeldb.AnchorOutputsBit | + channeldb.ZeroHtlcTxFeeBit | + channeldb.SingleFunderTweaklessBit + + testChanSyncOweRevocation(t, taprootBits) + }) + t.Run("taproot with tapscript root", func(t *testing.T) { + taprootBits := channeldb.SimpleTaprootFeatureBit | + channeldb.AnchorOutputsBit | + channeldb.ZeroHtlcTxFeeBit | + channeldb.SingleFunderTweaklessBit | + channeldb.TapscriptRootBit + + testChanSyncOweRevocation(t, taprootBits) + }) } func testChanSyncOweRevocationAndCommit(t *testing.T, @@ -3753,6 +3964,14 @@ func testChanSyncOweRevocationAndCommit(t *testing.T, bobNewCommit.HtlcSigs[i]) } } + + // If this is a taproot channel, then partial sig information + // should be present in the commit sig sent over. This + // signature will be re-regenerated, so we can't compare it + // with the old one. + if chanType.IsTaproot() { + require.True(t, bobReCommitSigMsg.PartialSig.IsSome()) + } } // We expect Bob to send exactly two messages: first his revocation @@ -3809,6 +4028,15 @@ func TestChanSyncOweRevocationAndCommit(t *testing.T) { testChanSyncOweRevocationAndCommit(t, taprootBits) }) + t.Run("taproot with tapscript root", func(t *testing.T) { + taprootBits := channeldb.SimpleTaprootFeatureBit | + channeldb.AnchorOutputsBit | + channeldb.ZeroHtlcTxFeeBit | + channeldb.SingleFunderTweaklessBit | + channeldb.TapscriptRootBit + + testChanSyncOweRevocationAndCommit(t, taprootBits) + }) } func testChanSyncOweRevocationAndCommitForceTransition(t *testing.T, @@ -4040,6 +4268,17 @@ func TestChanSyncOweRevocationAndCommitForceTransition(t *testing.T) { t, taprootBits, ) }) + t.Run("taproot with tapscript root", func(t *testing.T) { + taprootBits := channeldb.SimpleTaprootFeatureBit | + channeldb.AnchorOutputsBit | + channeldb.ZeroHtlcTxFeeBit | + channeldb.SingleFunderTweaklessBit | + channeldb.TapscriptRootBit + + testChanSyncOweRevocationAndCommitForceTransition( + t, taprootBits, + ) + }) } // TestChanSyncFailure tests the various scenarios during channel sync where we diff --git a/lnwallet/mock.go b/lnwallet/mock.go index 1873de79a84..89c31ad9857 100644 --- a/lnwallet/mock.go +++ b/lnwallet/mock.go @@ -17,7 +17,11 @@ import ( "github.com/btcsuite/btcwallet/wallet/txauthor" "github.com/btcsuite/btcwallet/wtxmgr" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/mock" ) var ( @@ -384,3 +388,46 @@ func (*mockChainIO) GetBlockHeader( return nil, nil } + +type auxSignerMock struct { + mock.Mock +} + +func (a *auxSignerMock) SubmitSecondLevelSigBatch( + chanState *channeldb.OpenChannel, + commitTx *wire.MsgTx, sigJobs []AuxSigJob) error { + + args := a.Called(chanState, commitTx, sigJobs) + + // While we return, we'll also send back an instant response for the + // set of jobs. + for _, sigJob := range sigJobs { + sigJob.Resp <- AuxSigJobResp{} + } + + return args.Error(0) +} + +func (a *auxSignerMock) PackSigs(sigs []fn.Option[tlv.Blob], +) (fn.Option[tlv.Blob], error) { + + args := a.Called(sigs) + + return args.Get(0).(fn.Option[tlv.Blob]), args.Error(1) +} + +func (a *auxSignerMock) UnpackSigs(sigs fn.Option[tlv.Blob]) ( + []fn.Option[tlv.Blob], error) { + + args := a.Called(sigs) + + return args.Get(0).([]fn.Option[tlv.Blob]), args.Error(1) +} + +func (a *auxSignerMock) VerifySecondLevelSigs(chanState *channeldb.OpenChannel, + commitTx *wire.MsgTx, verifyJob []AuxVerifyJob) error { + + args := a.Called(chanState, commitTx, verifyJob) + + return args.Error(0) +} diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index d7ac5df3e48..8ff8860e969 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -431,6 +431,25 @@ func CreateTestChannels(t *testing.T, chanType channeldb.ChannelType, return channelAlice, channelBob, nil } +// initMusigNonce is used to manually setup musig2 nonces for a new channel, +// outside the normal chan-reest flow. +func initMusigNonce(chanA, chanB *LightningChannel) error { + chanANonces, err := chanA.GenMusigNonces() + if err != nil { + return err + } + chanBNonces, err := chanB.GenMusigNonces() + if err != nil { + return err + } + + if err := chanA.InitRemoteMusigNonces(chanBNonces); err != nil { + return err + } + + return chanB.InitRemoteMusigNonces(chanANonces) +} + // initRevocationWindows simulates a new channel being opened within the p2p // network by populating the initial revocation windows of the passed // commitment state machines. @@ -439,19 +458,7 @@ func initRevocationWindows(chanA, chanB *LightningChannel) error { // either FundingLocked or ChannelReestablish by calling // InitRemoteMusigNonces for both sides. if chanA.channelState.ChanType.IsTaproot() { - chanANonces, err := chanA.GenMusigNonces() - if err != nil { - return err - } - chanBNonces, err := chanB.GenMusigNonces() - if err != nil { - return err - } - - if err := chanA.InitRemoteMusigNonces(chanBNonces); err != nil { - return err - } - if err := chanB.InitRemoteMusigNonces(chanANonces); err != nil { + if err := initMusigNonce(chanA, chanB); err != nil { return err } } diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index 7deb64ae1c1..c3f1a89b6bb 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -2,6 +2,7 @@ package lnwire import ( "bytes" + "fmt" "io" "github.com/lightningnetwork/lnd/tlv" @@ -45,6 +46,11 @@ type CommitSig struct { // being signed for. In this case, the above Sig type MUST be blank. PartialSig OptPartialSigWithNonceTLV + // CustomRecords maps TLV types to byte slices, storing arbitrary data + // intended for inclusion in the ExtraData field of the CommitSig + // message. + CustomRecords CustomRecords + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -62,8 +68,8 @@ func NewCommitSig() *CommitSig { // interface. var _ Message = (*CommitSig)(nil) -// Decode deserializes a serialized CommitSig message stored in the -// passed io.Reader observing the specified protocol version. +// Decode deserializes a serialized CommitSig message stored in the passed +// io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. func (c *CommitSig) Decode(r io.Reader, pver uint32) error { @@ -90,29 +96,57 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error { // Set the corresponding TLV types if they were included in the stream. if val, ok := typeMap[c.PartialSig.TlvType()]; ok && val == nil { c.PartialSig = tlv.SomeRecordT(partialSig) + + // Remove the entry from the TLV map. Anything left in the map + // will be included in the custom records field. + delete(typeMap, c.PartialSig.TlvType()) } - if len(tlvRecords) != 0 { - c.ExtraData = tlvRecords + // Parse through the remaining extra data map to separate the custom + // records, from the set of official records. + tlvTypes := newWireTlvMap(typeMap) + + // Set the custom records field to the custom records specific TLV + // record map. + customRecords, err := NewCustomRecordsFromTlvTypeMap( + tlvTypes.customTypes, + ) + if err != nil { + return err + } + c.CustomRecords = customRecords + + // Set custom records to nil if we didn't parse anything out of it so + // that we can use assert.Equal in tests. + if len(customRecords) == 0 { + c.CustomRecords = nil + } + + // Set extra data to nil if we didn't parse anything out of it so that + // we can use assert.Equal in tests. + if len(tlvTypes.officialTypes) == 0 { + c.ExtraData = nil + return nil + } + + // Encode the remaining records back into the extra data field. These + // records are not in the custom records TLV type range and do not have + // associated fields in the CommitSig struct. + c.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap( + tlvTypes.officialTypes, + ) + if err != nil { + return err } return nil } -// Encode serializes the target CommitSig into the passed io.Writer -// observing the protocol version specified. +// Encode serializes the target CommitSig into the passed io.Writer observing +// the protocol version specified. // // This is part of the lnwire.Message interface. func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error { - recordProducers := make([]tlv.RecordProducer, 0, 1) - c.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) { - recordProducers = append(recordProducers, &sig) - }) - err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) - if err != nil { - return err - } - if err := WriteChannelID(w, c.ChanID); err != nil { return err } @@ -125,7 +159,39 @@ func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error { return err } - return WriteBytes(w, c.ExtraData) + // Construct a slice of all the records that we should include in the + // message extra data field. We will start by including any records + // from the extra data field. + msgExtraDataRecords, err := c.ExtraData.RecordProducers() + if err != nil { + return err + } + + // Include the partial sig record if it is set. + c.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) { + msgExtraDataRecords = append(msgExtraDataRecords, &sig) + }) + + // Include custom records in the extra data wire field if they are + // present. Ensure that the custom records are validated before + // encoding them. + if err := c.CustomRecords.Validate(); err != nil { + return fmt.Errorf("custom records validation error: %w", err) + } + + // Extend the message extra data records slice with TLV records from + // the custom records field. + customTlvRecords := c.CustomRecords.RecordProducers() + msgExtraDataRecords = append(msgExtraDataRecords, customTlvRecords...) + + // We will now construct the message extra data field that will be + // encoded into the byte writer. + var msgExtraData ExtraOpaqueData + if err := msgExtraData.PackRecords(msgExtraDataRecords...); err != nil { + return err + } + + return WriteBytes(w, msgExtraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index b90988a7711..4b6a953a04d 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -1,9 +1,15 @@ package lnwire +// For some reason golangci-lint has a false positive on the sort order of the +// imports for the new "maps" package... We need the nolint directive here to +// ignore that. +// +//nolint:gci import ( "bytes" "fmt" "io" + "maps" "github.com/lightningnetwork/lnd/tlv" ) @@ -194,3 +200,57 @@ func EncodeMessageExtraData(extraData *ExtraOpaqueData, // are all properly sorted. return extraData.PackRecords(recordProducers...) } + +// wireTlvMap is a struct that holds the official records and custom records in +// a TLV type map. This is useful for ensuring that the set of custom TLV +// records are handled properly and don't overlap with the official records. +type wireTlvMap struct { + // officialTypes is the set of official records that are defined in the + // spec. + officialTypes tlv.TypeMap + + // customTypes is the set of custom records that are not defined in + // spec, and are used by higher level applications. + customTypes tlv.TypeMap +} + +// newWireTlvMap creates a new tlv.TypeMap from the given set of parsed TLV +// records. A struct with two maps are returned: +// +// 1. officialTypes: the set of official records that are defined in the +// spec. +// +// 2. customTypes: the set of custom records that are not defined in +// the spec. +func newWireTlvMap(typeMap tlv.TypeMap) wireTlvMap { + officialRecords := maps.Clone(typeMap) + + // Any records from the extra data TLV map which are in the custom + // records TLV type range will be included in the custom records field + // and removed from the extra data field. + customRecordsTlvMap := make(tlv.TypeMap, len(typeMap)) + for k, v := range typeMap { + // Skip records that are not in the custom records TLV type + // range. + if k < MinCustomRecordsTlvType { + continue + } + + // Include the record in the custom records map. + customRecordsTlvMap[k] = v + + // Now that the record is included in the custom records map, + // we can remove it from the extra data TLV map. + delete(officialRecords, k) + } + + return wireTlvMap{ + officialTypes: officialRecords, + customTypes: customRecordsTlvMap, + } +} + +// Len returns the total number of records in the wireTlvMap. +func (w *wireTlvMap) Len() int { + return len(w.officialTypes) + len(w.customTypes) +} diff --git a/lnwire/extra_bytes_test.go b/lnwire/extra_bytes_test.go index b05b19db5f8..98c7eeefca1 100644 --- a/lnwire/extra_bytes_test.go +++ b/lnwire/extra_bytes_test.go @@ -7,8 +7,11 @@ import ( "testing" "testing/quick" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" + "pgregory.net/rapid" ) // TestExtraOpaqueDataEncodeDecode tests that we're able to encode/decode @@ -206,3 +209,47 @@ func TestPackRecords(t *testing.T) { require.Equal(t, recordBytes2, extractedRecords[tlvType2.TypeVal()]) require.Equal(t, recordBytes3, extractedRecords[tlvType3.TypeVal()]) } + +// TestNewWireTlvMap tests the newWireTlvMap function using property-based +// testing. +func TestNewWireTlvMap(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + // Make a random type map, using the generic Make which'll + // figure out what type to generate. + tlvTypeMap := rapid.Make[tlv.TypeMap]().Draw(t, "typeMap") + + // Create a wireTlvMap from the generated type map, this'll + // operate on our random input. + result := newWireTlvMap(tlvTypeMap) + + // Property 1: The sum of lengths of officialTypes and + // customTypes should equal the length of the input typeMap. + require.Equal(t, len(tlvTypeMap), result.Len()) + + // Property 2: All types in customTypes should be >= + // MinCustomRecordsTlvType. + require.True(t, fn.All(func(k tlv.Type) bool { + return uint64(k) >= uint64(MinCustomRecordsTlvType) + }, maps.Keys(result.customTypes))) + + // Property 3: All types in officialTypes should be < + // MinCustomRecordsTlvType. + require.True(t, fn.All(func(k tlv.Type) bool { + return uint64(k) < uint64(MinCustomRecordsTlvType) + }, maps.Keys(result.officialTypes))) + + // Property 4: The union of officialTypes and customTypes + // should equal the input typeMap. + unionMap := make(tlv.TypeMap) + maps.Copy(unionMap, result.officialTypes) + maps.Copy(unionMap, result.customTypes) + require.Equal(t, tlvTypeMap, unionMap) + + // Property 5: No type should appear in both officialTypes and + // customTypes. + require.True(t, fn.All(func(k tlv.Type) bool { + _, exists := result.officialTypes[k] + return !exists + }, maps.Keys(result.customTypes))) + }) +} diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index c9413d4c2f4..5531f883ee2 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -887,7 +887,7 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(ks) }, MsgCommitSig: func(v []reflect.Value, r *rand.Rand) { - req := NewCommitSig() + req := &CommitSig{} if _, err := r.Read(req.ChanID[:]); err != nil { t.Fatalf("unable to generate chan id: %v", err) return @@ -915,6 +915,8 @@ func TestLightningWireProtocol(t *testing.T) { } } + req.CustomRecords = randCustomRecords(t, r) + // 50/50 chance to attach a partial sig. if r.Int31()%2 == 0 { req.PartialSig = somePartialSigWithNonce(t, r) @@ -1653,22 +1655,28 @@ func TestLightningWireProtocol(t *testing.T) { }, } for _, test := range tests { - var config *quick.Config - - // If the type defined is within the custom type gen map above, - // then we'll modify the default config to use this Value - // function that knows how to generate the proper types. - if valueGen, ok := customTypeGen[test.msgType]; ok { - config = &quick.Config{ - Values: valueGen, + t.Run(test.msgType.String(), func(t *testing.T) { + var config *quick.Config + + // If the type defined is within the custom type gen + // map above, then we'll modify the default config to + // use this Value function that knows how to generate + // the proper types. + if valueGen, ok := customTypeGen[test.msgType]; ok { + config = &quick.Config{ + Values: valueGen, + } } - } - t.Logf("Running fuzz tests for msgType=%v", test.msgType) - if err := quick.Check(test.scenario, config); err != nil { - t.Fatalf("fuzz checks for msg=%v failed: %v", - test.msgType, err) - } + t.Logf("Running fuzz tests for msgType=%v", + test.msgType) + + err := quick.Check(test.scenario, config) + if err != nil { + t.Fatalf("fuzz checks for msg=%v failed: %v", + test.msgType, err) + } + }) } } diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 3669f81e89a..982e66b8330 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -131,29 +131,14 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { delete(extraDataTlvMap, c.BlindingPoint.TlvType()) } - // Any records from the extra data TLV map which are in the custom - // records TLV type range will be included in the custom records field - // and removed from the extra data field. - customRecordsTlvMap := make(tlv.TypeMap, len(extraDataTlvMap)) - for k, v := range extraDataTlvMap { - // Skip records that are not in the custom records TLV type - // range. - if k < MinCustomRecordsTlvType { - continue - } - - // Include the record in the custom records map. - customRecordsTlvMap[k] = v - - // Now that the record is included in the custom records map, - // we can remove it from the extra data TLV map. - delete(extraDataTlvMap, k) - } + // Parse through the remaining extra data map to separate the custom + // records, from the set of official records. + tlvTypes := newWireTlvMap(extraDataTlvMap) // Set the custom records field to the custom records specific TLV // record map. customRecords, err := NewCustomRecordsFromTlvTypeMap( - customRecordsTlvMap, + tlvTypes.customTypes, ) if err != nil { return err @@ -162,21 +147,23 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { // Set custom records to nil if we didn't parse anything out of it so // that we can use assert.Equal in tests. - if len(customRecordsTlvMap) == 0 { + if len(customRecords) == 0 { c.CustomRecords = nil } // Set extra data to nil if we didn't parse anything out of it so that // we can use assert.Equal in tests. - if len(extraDataTlvMap) == 0 { + if len(tlvTypes.officialTypes) == 0 { c.ExtraData = nil return nil } // Encode the remaining records back into the extra data field. These - // records are not in the custom records TLV type range and do not - // have associated fields in the UpdateAddHTLC struct. - c.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap(extraDataTlvMap) + // records are not in the custom records TLV type range and do not have + // associated fields in the UpdateAddHTLC struct. + c.ExtraData, err = NewExtraOpaqueDataFromTlvTypeMap( + tlvTypes.officialTypes, + ) if err != nil { return err }