From fac6586de985882fd0318dfe32221eed12ddb954 Mon Sep 17 00:00:00 2001 From: t-bast Date: Thu, 9 Dec 2021 13:48:57 +0100 Subject: [PATCH] Add support for option_shutdown_anysegwit Opt-in to allow any future segwit script in shutdown as long as it complies with BIP 141 (see lightning/bolts#672). This is particularly useful to allow wallet users to close channels to a Taproot address. --- .../kotlin/fr/acinq/lightning/Features.kt | 7 +++ .../fr/acinq/lightning/channel/Channel.kt | 6 ++- .../fr/acinq/lightning/channel/Helpers.kt | 18 +++++-- .../fr/acinq/lightning/FeaturesTestsCommon.kt | 5 +- .../channel/states/NormalTestsCommon.kt | 48 +++++++++++++++++-- 5 files changed, 71 insertions(+), 13 deletions(-) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/Features.kt b/src/commonMain/kotlin/fr/acinq/lightning/Features.kt index 6d6d6f64a..213edc184 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/Features.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/Features.kt @@ -91,6 +91,12 @@ sealed class Feature { override val mandatory get() = 20 } + @Serializable + object ShutdownAnySegwit : Feature() { + override val rfcName get() = "option_shutdown_anysegwit" + override val mandatory get() = 26 + } + @Serializable object ChannelType : Feature() { override val rfcName get() = "option_channel_type" @@ -228,6 +234,7 @@ data class Features(val activated: Map, val unknown: Se Feature.BasicMultiPartPayment, Feature.Wumbo, Feature.AnchorOutputs, + Feature.ShutdownAnySegwit, Feature.ChannelType, Feature.TrampolinePayment, Feature.ZeroReserveChannels, diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/Channel.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/Channel.kt index 30b7fd999..da5d40924 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/Channel.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/Channel.kt @@ -1887,12 +1887,13 @@ data class Normal( } } is CMD_CLOSE -> { + val allowAnySegwit = Features.canUseFeature(commitments.localParams.features, commitments.remoteParams.features, Feature.ShutdownAnySegwit) val localScriptPubkey = event.command.scriptPubKey ?: commitments.localParams.defaultFinalScriptPubKey when { this.localShutdown != null -> handleCommandError(event.command, ClosingAlreadyInProgress(channelId), channelUpdate) this.commitments.localHasUnsignedOutgoingHtlcs() -> handleCommandError(event.command, CannotCloseWithUnsignedOutgoingHtlcs(channelId), channelUpdate) this.commitments.localHasUnsignedOutgoingUpdateFee() -> handleCommandError(event.command, CannotCloseWithUnsignedOutgoingUpdateFee(channelId), channelUpdate) - !Helpers.Closing.isValidFinalScriptPubkey(localScriptPubkey) -> handleCommandError(event.command, InvalidFinalScript(channelId), channelUpdate) + !Helpers.Closing.isValidFinalScriptPubkey(localScriptPubkey, allowAnySegwit) -> handleCommandError(event.command, InvalidFinalScript(channelId), channelUpdate) else -> { val shutdown = Shutdown(channelId, localScriptPubkey) val newState = this.copy(localShutdown = shutdown, closingFeerates = event.command.feerates) @@ -1995,6 +1996,7 @@ data class Normal( } } is Shutdown -> { + val allowAnySegwit = Features.canUseFeature(commitments.localParams.features, commitments.remoteParams.features, Feature.ShutdownAnySegwit) // they have pending unsigned htlcs => they violated the spec, close the channel // they don't have pending unsigned htlcs // we have pending unsigned htlcs @@ -2010,7 +2012,7 @@ data class Normal( // there are pending signed changes => go to SHUTDOWN // there are no changes => go to NEGOTIATING when { - !Helpers.Closing.isValidFinalScriptPubkey(event.message.scriptPubKey) -> handleLocalError(event, InvalidFinalScript(channelId)) + !Helpers.Closing.isValidFinalScriptPubkey(event.message.scriptPubKey, allowAnySegwit) -> handleLocalError(event, InvalidFinalScript(channelId)) commitments.remoteHasUnsignedOutgoingHtlcs() -> handleLocalError(event, CannotCloseWithUnsignedOutgoingHtlcs(channelId)) commitments.remoteHasUnsignedOutgoingUpdateFee() -> handleLocalError(event, CannotCloseWithUnsignedOutgoingUpdateFee(channelId)) commitments.localHasUnsignedOutgoingHtlcs() -> { diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/Helpers.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/Helpers.kt index f123a5cd8..2ba02e4c0 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/Helpers.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/Helpers.kt @@ -384,14 +384,21 @@ object Helpers { // used only to compute tx weights and estimate fees private val dummyPublicKey by lazy { PrivateKey(ByteArray(32) { 1.toByte() }).publicKey() } - private fun isValidFinalScriptPubkey(scriptPubKey: ByteArray): Boolean { + private fun isValidFinalScriptPubkey(scriptPubKey: ByteArray, allowAnySegwit: Boolean): Boolean { return runTrying { val script = Script.parse(scriptPubKey) - Script.isPay2pkh(script) || Script.isPay2sh(script) || Script.isPay2wpkh(script) || Script.isPay2wsh(script) + when { + Script.isPay2pkh(script) -> true + Script.isPay2sh(script) -> true + Script.isPay2wpkh(script) -> true + Script.isPay2wsh(script) -> true + Script.isNativeWitnessScript(script) -> allowAnySegwit + else -> false + } }.getOrElse { false } } - fun isValidFinalScriptPubkey(scriptPubKey: ByteVector): Boolean = isValidFinalScriptPubkey(scriptPubKey.toByteArray()) + fun isValidFinalScriptPubkey(scriptPubKey: ByteVector, allowAnySegwit: Boolean): Boolean = isValidFinalScriptPubkey(scriptPubKey.toByteArray(), allowAnySegwit) // To be replaced with corresponding function in bitcoin-kmp fun btcAddressFromScriptPubKey(scriptPubKey: ByteVector, chainHash: ByteVector32): String? { @@ -465,8 +472,9 @@ object Helpers { remoteScriptPubkey: ByteArray, closingFees: ClosingFees ): Pair { - require(isValidFinalScriptPubkey(localScriptPubkey)) { "invalid localScriptPubkey" } - require(isValidFinalScriptPubkey(remoteScriptPubkey)) { "invalid remoteScriptPubkey" } + val allowAnySegwit = Features.canUseFeature(commitments.localParams.features, commitments.remoteParams.features, Feature.ShutdownAnySegwit) + require(isValidFinalScriptPubkey(localScriptPubkey, allowAnySegwit)) { "invalid localScriptPubkey" } + require(isValidFinalScriptPubkey(remoteScriptPubkey, allowAnySegwit)) { "invalid remoteScriptPubkey" } val dustLimit = commitments.localParams.dustLimit.max(commitments.remoteParams.dustLimit) val closingTx = Transactions.makeClosingTx(commitments.commitInput, localScriptPubkey, remoteScriptPubkey, commitments.localParams.isFunder, dustLimit, closingFees.preferred, commitments.localCommit.spec) val localClosingSig = keyManager.sign(closingTx, commitments.localParams.channelKeys.fundingPrivateKey) diff --git a/src/commonTest/kotlin/fr/acinq/lightning/FeaturesTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/FeaturesTestsCommon.kt index 22e78d00a..919027da2 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/FeaturesTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/FeaturesTestsCommon.kt @@ -208,9 +208,10 @@ class FeaturesTestsCommon : LightningTestSuite() { byteArrayOf(0x09, 0x00, 0x42, 0x00) to Features( mapOf( VariableLengthOnion to FeatureSupport.Optional, - PaymentSecret to FeatureSupport.Mandatory + PaymentSecret to FeatureSupport.Mandatory, + ShutdownAnySegwit to FeatureSupport.Optional ), - setOf(UnknownFeature(24), UnknownFeature(27)) + setOf(UnknownFeature(24)) ), byteArrayOf(0x52, 0x00, 0x00, 0x00) to Features( mapOf(), diff --git a/src/commonTest/kotlin/fr/acinq/lightning/channel/states/NormalTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/channel/states/NormalTestsCommon.kt index 6b3931d91..b63683cc5 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/channel/states/NormalTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/channel/states/NormalTestsCommon.kt @@ -1,11 +1,8 @@ package fr.acinq.lightning.channel.states import fr.acinq.bitcoin.* -import fr.acinq.lightning.CltvExpiry -import fr.acinq.lightning.CltvExpiryDelta -import fr.acinq.lightning.Feature +import fr.acinq.lightning.* import fr.acinq.lightning.Lightning.randomBytes32 -import fr.acinq.lightning.ShortChannelId import fr.acinq.lightning.blockchain.* import fr.acinq.lightning.blockchain.fee.FeeratePerKw import fr.acinq.lightning.channel.* @@ -1413,6 +1410,28 @@ class NormalTestsCommon : LightningTestSuite() { actions1.hasCommandError() } + @Test + fun `recv CMD_CLOSE (with unsupported native segwit script)`() { + val (alice, _) = reachNormal() + assertNull(alice.localShutdown) + val (alice1, actions1) = alice.processEx(ChannelEvent.ExecuteCommand(CMD_CLOSE(ByteVector("51050102030405"), null))) + assertTrue(alice1 is Normal) + actions1.hasCommandError() + } + + @Test + fun `recv CMD_CLOSE (with native segwit script)`() { + val (alice, _) = reachNormal( + aliceFeatures = TestConstants.Alice.nodeParams.features.copy(TestConstants.Alice.nodeParams.features.activated + (Feature.ShutdownAnySegwit to FeatureSupport.Optional)), + bobFeatures = TestConstants.Bob.nodeParams.features.copy(TestConstants.Bob.nodeParams.features.activated + (Feature.ShutdownAnySegwit to FeatureSupport.Optional)), + ) + assertNull(alice.localShutdown) + val (alice1, actions1) = alice.processEx(ChannelEvent.ExecuteCommand(CMD_CLOSE(ByteVector("51050102030405"), null))) + actions1.hasOutgoingMessage() + assertTrue(alice1 is Normal) + assertNotNull(alice1.localShutdown) + } + @Test fun `recv CMD_CLOSE (with signed sent htlcs)`() { val (alice, bob) = reachNormal() @@ -1551,6 +1570,27 @@ class NormalTestsCommon : LightningTestSuite() { actions1.hasWatch() } + @Test + fun `recv Shutdown (with unsupported native segwit script)`() { + val (_, bob) = reachNormal() + val (bob1, actions1) = bob.processEx(ChannelEvent.MessageReceived(Shutdown(bob.channelId, ByteVector("51050102030405")))) + assertTrue(bob1 is Closing) + actions1.hasOutgoingMessage() + assertEquals(2, actions1.filterIsInstance().count()) + actions1.hasWatch() + } + + @Test + fun `recv Shutdown (with native segwit script)`() { + val (_, bob) = reachNormal( + aliceFeatures = TestConstants.Alice.nodeParams.features.copy(TestConstants.Alice.nodeParams.features.activated + (Feature.ShutdownAnySegwit to FeatureSupport.Optional)), + bobFeatures = TestConstants.Bob.nodeParams.features.copy(TestConstants.Bob.nodeParams.features.activated + (Feature.ShutdownAnySegwit to FeatureSupport.Optional)), + ) + val (bob1, actions1) = bob.processEx(ChannelEvent.MessageReceived(Shutdown(bob.channelId, ByteVector("51050102030405")))) + assertTrue(bob1 is Negotiating) + actions1.hasOutgoingMessage() + } + @Test fun `recv Shutdown (with invalid final script and signed htlcs, in response to a Shutdown)`() { val (alice, bob) = reachNormal()