Skip to content

Commit ea8f940

Browse files
authored
Fix potential race condition in node-relay (#1716)
We previously relied on `context.child` to check whether we already had a relay handler for a given payment_hash. Unfortunately, this could return an actor that was in the process of stopping itself. When that happens, our relay command can end up in the dead letters, and the payment is neither relayed nor failed upstream. We fix that by maintaining the map of current relay handlers in the NodeRelayer and removing a handler from that map before stopping it. This is similar to what's done in the MultiPartPaymentFSM.
1 parent 92e53dc commit ea8f940

File tree

3 files changed

+117
-41
lines changed

3 files changed

+117
-41
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala

+24-13
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
package fr.acinq.eclair.payment.relay
1818

19-
import java.util.UUID
20-
2119
import akka.actor.ActorRef
2220
import akka.actor.typed.Behavior
2321
import akka.actor.typed.eventstream.EventStream
@@ -41,6 +39,7 @@ import fr.acinq.eclair.router.{BalanceTooLow, RouteCalculation, RouteNotFound}
4139
import fr.acinq.eclair.wire._
4240
import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, nodeFee, randomBytes32}
4341

42+
import java.util.UUID
4443
import scala.collection.immutable.Queue
4544

4645
/**
@@ -52,6 +51,7 @@ object NodeRelay {
5251
// @formatter:off
5352
sealed trait Command
5453
case class Relay(nodeRelayPacket: IncomingPacket.NodeRelayPacket) extends Command
54+
case object Stop extends Command
5555
private case class WrappedMultiPartExtraPaymentReceived(mppExtraReceived: MultiPartPaymentFSM.ExtraPaymentReceived[HtlcPart]) extends Command
5656
private case class WrappedMultiPartPaymentFailed(mppFailed: MultiPartPaymentFSM.MultiPartPaymentFailed) extends Command
5757
private case class WrappedMultiPartPaymentSucceeded(mppSucceeded: MultiPartPaymentFSM.MultiPartPaymentSucceeded) extends Command
@@ -60,13 +60,13 @@ object NodeRelay {
6060
private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command
6161
// @formatter:on
6262

63-
def apply(nodeParams: NodeParams, router: ActorRef, register: ActorRef, relayId: UUID, paymentHash: ByteVector32, fsmFactory: FsmFactory = new FsmFactory): Behavior[Command] =
63+
def apply(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], router: ActorRef, register: ActorRef, relayId: UUID, paymentHash: ByteVector32, fsmFactory: FsmFactory = new FsmFactory): Behavior[Command] =
6464
Behaviors.setup { context =>
6565
Behaviors.withMdc(Logs.mdc(
6666
category_opt = Some(Logs.LogCategory.PAYMENT),
6767
parentPaymentId_opt = Some(relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment
6868
paymentHash_opt = Some(paymentHash))) {
69-
new NodeRelay(nodeParams, router, register, relayId, paymentHash, context, fsmFactory)()
69+
new NodeRelay(nodeParams, parent, router, register, relayId, paymentHash, context, fsmFactory)()
7070
}
7171
}
7272

@@ -136,6 +136,7 @@ object NodeRelay {
136136
* see https://doc.akka.io/docs/akka/current/typed/style-guide.html#passing-around-too-many-parameters
137137
*/
138138
class NodeRelay private(nodeParams: NodeParams,
139+
parent: akka.actor.typed.ActorRef[NodeRelayer.Command],
139140
router: ActorRef,
140141
register: ActorRef,
141142
relayId: UUID,
@@ -164,7 +165,7 @@ class NodeRelay private(nodeParams: NodeParams,
164165
// TODO: @pm: maybe those checks should be done later in the flow (by the mpp FSM?)
165166
context.log.warn("rejecting htlcId={}: missing payment secret", add.id)
166167
rejectHtlc(add.id, add.channelId, add.amountMsat)
167-
Behaviors.stopped
168+
stopping()
168169
case Some(secret) =>
169170
import akka.actor.typed.scaladsl.adapter._
170171
context.log.info("relaying payment relayId={}", relayId)
@@ -205,15 +206,15 @@ class NodeRelay private(nodeParams: NodeParams,
205206
context.log.warn("could not complete incoming multi-part payment (parts={} paidAmount={} failure={})", parts.size, parts.map(_.amount).sum, failure)
206207
Metrics.recordPaymentRelayFailed(failure.getClass.getSimpleName, Tags.RelayType.Trampoline)
207208
parts.collect { case p: MultiPartPaymentFSM.HtlcPart => rejectHtlc(p.htlc.id, p.htlc.channelId, p.amount, Some(failure)) }
208-
Behaviors.stopped
209+
stopping()
209210
case WrappedMultiPartPaymentSucceeded(MultiPartPaymentFSM.MultiPartPaymentSucceeded(_, parts)) =>
210211
context.log.info("completed incoming multi-part payment with parts={} paidAmount={}", parts.size, parts.map(_.amount).sum)
211212
val upstream = Upstream.Trampoline(htlcs)
212213
validateRelay(nodeParams, upstream, nextPayload) match {
213214
case Some(failure) =>
214215
context.log.warn(s"rejecting trampoline payment reason=$failure")
215216
rejectPayment(upstream, Some(failure))
216-
Behaviors.stopped
217+
stopping()
217218
case None =>
218219
doSend(upstream, nextPayload, nextPacket)
219220
}
@@ -249,16 +250,27 @@ class NodeRelay private(nodeParams: NodeParams,
249250
case WrappedPaymentSent(paymentSent) =>
250251
context.log.debug("trampoline payment fully resolved downstream")
251252
success(upstream, fulfilledUpstream, paymentSent)
252-
Behaviors.stopped
253-
case WrappedPaymentFailed(PaymentFailed(id, _, failures, _)) =>
253+
stopping()
254+
case WrappedPaymentFailed(PaymentFailed(_, _, failures, _)) =>
254255
context.log.debug(s"trampoline payment failed downstream")
255256
if (!fulfilledUpstream) {
256257
rejectPayment(upstream, translateError(nodeParams, failures, upstream, nextPayload))
257258
}
258-
Behaviors.stopped
259+
stopping()
259260
}
260261
}
261262

263+
/**
264+
* Once the relay is complete (the incoming payment was rejected, or the downstream payment was fulfilled or failed), we notify our parent and reject new upstream payments while we wait for our parent to stop us.
265+
*/
266+
private def stopping(): Behavior[Command] = {
267+
parent ! NodeRelayer.RelayComplete(context.self, paymentHash)
268+
Behaviors.receiveMessagePartial {
269+
rejectExtraHtlcPartialFunction orElse {
270+
case Stop => Behaviors.stopped
271+
}
272+
}
273+
}
262274

263275
private def relay(upstream: Upstream.Trampoline, payloadOut: Onion.NodeRelayPayload, packetOut: OnionRoutingPacket): ActorRef = {
264276
val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, payloadOut.amountToForward, payloadOut.outgoingNodeId, upstream, None, storeInDb = false, publishEvent = false, Nil)
@@ -297,11 +309,10 @@ class NodeRelay private(nodeParams: NodeParams,
297309
case Relay(nodeRelayPacket) =>
298310
rejectExtraHtlc(nodeRelayPacket.add)
299311
Behaviors.same
300-
// NB: this messages would be sent from the payment FSM which we stopped before going to this state, but all
301-
// this is asynchronous
312+
// NB: this message would be sent from the payment FSM which we stopped before going to this state, but all this is asynchronous.
302313
// We always fail extraneous HTLCs. They are a spec violation from the sender, but harmless in the relay case.
303314
// By failing them fast (before the payment has reached the final recipient) there's a good chance the sender won't lose any money.
304-
// We don't expect to relay pay-to-open payments
315+
// We don't expect to relay pay-to-open payments.
305316
case WrappedMultiPartExtraPaymentReceived(extraPaymentReceived) =>
306317
rejectExtraHtlc(extraPaymentReceived.payment.htlc)
307318
Behaviors.same

eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala

+38-24
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616

1717
package fr.acinq.eclair.payment.relay
1818

19-
import java.util.UUID
20-
21-
import akka.actor.ActorRef
22-
import akka.actor.typed.Behavior
2319
import akka.actor.typed.scaladsl.Behaviors
20+
import akka.actor.typed.{ActorRef, Behavior}
21+
import fr.acinq.bitcoin.ByteVector32
2422
import fr.acinq.eclair.payment._
2523
import fr.acinq.eclair.{Logs, NodeParams}
2624

25+
import java.util.UUID
26+
2727
/**
2828
* Created by t-bast on 10/10/2019.
2929
*/
@@ -38,33 +38,47 @@ object NodeRelayer {
3838
// @formatter:off
3939
sealed trait Command
4040
case class Relay(nodeRelayPacket: IncomingPacket.NodeRelayPacket) extends Command
41+
case class RelayComplete(childHandler: ActorRef[NodeRelay.Command], paymentHash: ByteVector32) extends Command
42+
private[relay] case class GetPendingPayments(replyTo: akka.actor.ActorRef) extends Command
4143
// @formatter:on
4244

4345
def mdc: Command => Map[String, String] = {
44-
case c: Relay => Logs.mdc(
45-
paymentHash_opt = Some(c.nodeRelayPacket.add.paymentHash))
46+
case c: Relay => Logs.mdc(paymentHash_opt = Some(c.nodeRelayPacket.add.paymentHash))
47+
case c: RelayComplete => Logs.mdc(paymentHash_opt = Some(c.paymentHash))
48+
case _: GetPendingPayments => Logs.mdc()
4649
}
4750

48-
def apply(nodeParams: NodeParams, router: ActorRef, register: ActorRef): Behavior[Command] =
51+
/**
52+
* @param children a map of current in-process payments, indexed by payment hash and purposefully *not* by payment id,
53+
* because that is how we aggregate payment parts (when the incoming payment uses MPP).
54+
*/
55+
def apply(nodeParams: NodeParams, router: akka.actor.ActorRef, register: akka.actor.ActorRef, children: Map[ByteVector32, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] =
4956
Behaviors.setup { context =>
5057
Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT)), mdc) {
51-
Behaviors.receiveMessage {
52-
case Relay(nodeRelayPacket) =>
53-
import nodeRelayPacket.add.paymentHash
54-
val handler = context.child(paymentHash.toString) match {
55-
case Some(handler) =>
56-
// NB: we could also maintain a list of children
57-
handler.unsafeUpcast[NodeRelay.Command] // we know that all children are of type NodeRelay
58-
case None =>
59-
val relayId = UUID.randomUUID()
60-
context.log.debug(s"spawning a new handler with relayId=$relayId")
61-
// we index children by paymentHash, not relayId, because there is no concept of individual payment on LN
62-
context.spawn(NodeRelay.apply(nodeParams, router, register, relayId, paymentHash), name = paymentHash.toString)
63-
}
64-
context.log.debug("forwarding incoming htlc to handler")
65-
handler ! NodeRelay.Relay(nodeRelayPacket)
66-
Behaviors.same
67-
}
58+
Behaviors.receiveMessage {
59+
case Relay(nodeRelayPacket) =>
60+
import nodeRelayPacket.add.paymentHash
61+
children.get(paymentHash) match {
62+
case Some(handler) =>
63+
context.log.debug("forwarding incoming htlc to existing handler")
64+
handler ! NodeRelay.Relay(nodeRelayPacket)
65+
Behaviors.same
66+
case None =>
67+
val relayId = UUID.randomUUID()
68+
context.log.debug(s"spawning a new handler with relayId=$relayId")
69+
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, router, register, relayId, paymentHash), relayId.toString)
70+
context.log.debug("forwarding incoming htlc to new handler")
71+
handler ! NodeRelay.Relay(nodeRelayPacket)
72+
apply(nodeParams, router, register, children + (paymentHash -> handler))
73+
}
74+
case RelayComplete(childHandler, paymentHash) =>
75+
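// we do a back-and-forth between parent and child before stopping the child to prevent a race condition: the child is removed from our map in the same step, so no new htlc is forwarded to a handler that is stopping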
76+
childHandler ! NodeRelay.Stop
77+
apply(nodeParams, router, register, children - paymentHash)
78+
case GetPendingPayments(replyTo) =>
79+
replyTo ! children
80+
Behaviors.same
81+
}
6882
}
6983
}
7084
}

eclair-core/src/test/scala/fr/acinq/eclair/payment/relay/NodeRelayerSpec.scala

+55-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import akka.actor.typed.eventstream.EventStream
2323
import akka.actor.typed.scaladsl.ActorContext
2424
import akka.actor.typed.scaladsl.adapter._
2525
import com.typesafe.config.ConfigFactory
26-
import fr.acinq.bitcoin.{Block, Crypto}
26+
import fr.acinq.bitcoin.{Block, ByteVector32, Crypto}
2727
import fr.acinq.eclair.Features.{BasicMultiPartPayment, PaymentSecret, VariableLengthOnion}
2828
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Register}
2929
import fr.acinq.eclair.crypto.Sphinx
@@ -53,10 +53,11 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
5353

5454
import NodeRelayerSpec._
5555

56-
case class FixtureParam(nodeParams: NodeParams, nodeRelayer: ActorRef[NodeRelay.Command], router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent])
56+
case class FixtureParam(nodeParams: NodeParams, nodeRelayer: ActorRef[NodeRelay.Command], parent: TestProbe[NodeRelayer.Command], router: TestProbe[Any], register: TestProbe[Any], mockPayFSM: TestProbe[Any], eventListener: TestProbe[PaymentEvent])
5757

5858
override def withFixture(test: OneArgTest): Outcome = {
5959
val nodeParams = TestConstants.Bob.nodeParams.copy(multiPartPaymentExpiry = 5 seconds)
60+
val parent = TestProbe[NodeRelayer.Command]("parent-relayer")
6061
val router = TestProbe[Any]("router")
6162
val register = TestProbe[Any]("register")
6263
val eventListener = TestProbe[PaymentEvent]("event-listener")
@@ -78,8 +79,53 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
7879
}
7980
}
8081
}
81-
val nodeRelay = testKit.spawn(NodeRelay(nodeParams, router.ref.toClassic, register.ref.toClassic, relayId, paymentHash, fsmFactory))
82-
withFixture(test.toNoArgTest(FixtureParam(nodeParams, nodeRelay, router, register, mockPayFSM, eventListener)))
82+
val nodeRelay = testKit.spawn(NodeRelay(nodeParams, parent.ref, router.ref.toClassic, register.ref.toClassic, relayId, paymentHash, fsmFactory))
83+
withFixture(test.toNoArgTest(FixtureParam(nodeParams, nodeRelay, parent, router, register, mockPayFSM, eventListener)))
84+
}
85+
86+
test("stop child handler when relay is complete") { f =>
87+
import f._
88+
val probe = TestProbe[Any]
89+
90+
{
91+
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, router.ref.toClassic, register.ref.toClassic))
92+
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
93+
probe.expectMessage(Map.empty)
94+
}
95+
{
96+
val (paymentHash1, child1) = (randomBytes32, TestProbe[NodeRelay.Command])
97+
val (paymentHash2, child2) = (randomBytes32, TestProbe[NodeRelay.Command])
98+
val children = Map(paymentHash1 -> child1.ref, paymentHash2 -> child2.ref)
99+
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, router.ref.toClassic, register.ref.toClassic, children))
100+
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
101+
probe.expectMessage(children)
102+
103+
parentRelayer ! NodeRelayer.RelayComplete(child1.ref, paymentHash1)
104+
child1.expectMessage(NodeRelay.Stop)
105+
parentRelayer ! NodeRelayer.RelayComplete(child1.ref, paymentHash1)
106+
child1.expectMessage(NodeRelay.Stop)
107+
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
108+
probe.expectMessage(children - paymentHash1)
109+
}
110+
{
111+
val parentRelayer = testKit.spawn(NodeRelayer(nodeParams, router.ref.toClassic, register.ref.toClassic))
112+
parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head)
113+
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
114+
val pending1 = probe.expectMessageType[Map[ByteVector32, ActorRef[NodeRelay.Command]]]
115+
assert(pending1.size === 1)
116+
assert(pending1.head._1 === paymentHash)
117+
118+
parentRelayer ! NodeRelayer.RelayComplete(pending1.head._2, paymentHash)
119+
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
120+
probe.expectMessage(Map.empty)
121+
122+
parentRelayer ! NodeRelayer.Relay(incomingMultiPart.head)
123+
parentRelayer ! NodeRelayer.GetPendingPayments(probe.ref.toClassic)
124+
val pending2 = probe.expectMessageType[Map[ByteVector32, ActorRef[NodeRelay.Command]]]
125+
assert(pending2.size === 1)
126+
assert(pending2.head._1 === paymentHash)
127+
assert(pending2.head._2 !== pending1.head._2)
128+
}
83129
}
84130

85131
test("fail to relay when incoming multi-part payment times out") { f =>
@@ -95,6 +141,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
95141
assert(fwd.message === CMD_FAIL_HTLC(p.add.id, failure, commit = true))
96142
}
97143

144+
parent.expectMessageType[NodeRelayer.RelayComplete]
98145
register.expectNoMessage(100 millis)
99146
}
100147

@@ -398,6 +445,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
398445
validateRelayEvent(relayEvent)
399446
assert(relayEvent.incoming.toSet === incomingMultiPart.map(i => PaymentRelayed.Part(i.add.amountMsat, i.add.channelId)).toSet)
400447
assert(relayEvent.outgoing.nonEmpty)
448+
parent.expectMessageType[NodeRelayer.RelayComplete]
401449
register.expectNoMessage(100 millis)
402450
}
403451

@@ -425,6 +473,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
425473
validateRelayEvent(relayEvent)
426474
assert(relayEvent.incoming === Seq(PaymentRelayed.Part(incomingSinglePart.add.amountMsat, incomingSinglePart.add.channelId)))
427475
assert(relayEvent.outgoing.nonEmpty)
476+
parent.expectMessageType[NodeRelayer.RelayComplete]
428477
register.expectNoMessage(100 millis)
429478
}
430479

@@ -464,6 +513,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
464513
validateRelayEvent(relayEvent)
465514
assert(relayEvent.incoming === incomingMultiPart.map(i => PaymentRelayed.Part(i.add.amountMsat, i.add.channelId)))
466515
assert(relayEvent.outgoing.nonEmpty)
516+
parent.expectMessageType[NodeRelayer.RelayComplete]
467517
register.expectNoMessage(100 millis)
468518
}
469519

@@ -500,6 +550,7 @@ class NodeRelayerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("appl
500550
validateRelayEvent(relayEvent)
501551
assert(relayEvent.incoming === incomingMultiPart.map(i => PaymentRelayed.Part(i.add.amountMsat, i.add.channelId)))
502552
assert(relayEvent.outgoing.length === 1)
553+
parent.expectMessageType[NodeRelayer.RelayComplete]
503554
register.expectNoMessage(100 millis)
504555
}
505556

0 commit comments

Comments
 (0)