@@ -29,11 +29,10 @@ import fr.acinq.eclair.payment.OutgoingPacket.Upstream
29
29
import fr .acinq .eclair .payment ._
30
30
import fr .acinq .eclair .payment .receive .MultiPartPaymentFSM
31
31
import fr .acinq .eclair .payment .receive .MultiPartPaymentFSM .HtlcPart
32
- import fr .acinq .eclair .payment .relay .NodeRelay .FsmFactory
33
32
import fr .acinq .eclair .payment .send .MultiPartPaymentLifecycle .{PreimageReceived , SendMultiPartPayment }
34
33
import fr .acinq .eclair .payment .send .PaymentInitiator .SendPaymentConfig
35
34
import fr .acinq .eclair .payment .send .PaymentLifecycle .SendPayment
36
- import fr .acinq .eclair .payment .send .{MultiPartPaymentLifecycle , PaymentLifecycle }
35
+ import fr .acinq .eclair .payment .send .{MultiPartPaymentLifecycle , PaymentInitiator , PaymentLifecycle }
37
36
import fr .acinq .eclair .router .Router .RouteParams
38
37
import fr .acinq .eclair .router .{BalanceTooLow , RouteCalculation , RouteNotFound }
39
38
import fr .acinq .eclair .wire .protocol ._
@@ -60,29 +59,32 @@ object NodeRelay {
60
59
private case class WrappedPaymentFailed (paymentFailed : PaymentFailed ) extends Command
61
60
// @formatter:on
62
61
63
- def apply (nodeParams : NodeParams , parent : akka.actor.typed.ActorRef [NodeRelayer .Command ], router : ActorRef , register : ActorRef , relayId : UUID , paymentHash : ByteVector32 , fsmFactory : FsmFactory = new FsmFactory ): Behavior [Command ] =
64
- Behaviors .setup { context =>
65
- Behaviors .withMdc(Logs .mdc(
66
- category_opt = Some (Logs .LogCategory .PAYMENT ),
67
- parentPaymentId_opt = Some (relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment
68
- paymentHash_opt = Some (paymentHash))) {
69
- new NodeRelay (nodeParams, parent, router, register, relayId, paymentHash, context, fsmFactory)()
70
- }
71
- }
62
+ trait OutgoingPaymentFactory {
63
+ def spawnOutgoingPayFSM (context : ActorContext [NodeRelay .Command ], cfg : SendPaymentConfig , multiPart : Boolean ): ActorRef
64
+ }
72
65
73
- /**
74
- * This is supposed to be overridden in tests
75
- */
76
- class FsmFactory {
77
- def spawnOutgoingPayFSM (context : ActorContext [NodeRelay .Command ], nodeParams : NodeParams , router : ActorRef , register : ActorRef , cfg : SendPaymentConfig , multiPart : Boolean ): ActorRef = {
66
+ case class SimpleOutgoingPaymentFactory (nodeParams : NodeParams , router : ActorRef , register : ActorRef ) extends OutgoingPaymentFactory {
67
+ val paymentFactory = PaymentInitiator .SimplePaymentFactory (nodeParams, router, register)
68
+
69
+ override def spawnOutgoingPayFSM (context : ActorContext [Command ], cfg : SendPaymentConfig , multiPart : Boolean ): ActorRef = {
78
70
if (multiPart) {
79
- context.toClassic.actorOf(MultiPartPaymentLifecycle .props(nodeParams, cfg, router, register ))
71
+ context.toClassic.actorOf(MultiPartPaymentLifecycle .props(nodeParams, cfg, router, paymentFactory ))
80
72
} else {
81
73
context.toClassic.actorOf(PaymentLifecycle .props(nodeParams, cfg, router, register))
82
74
}
83
75
}
84
76
}
85
77
78
+ def apply (nodeParams : NodeParams , parent : akka.actor.typed.ActorRef [NodeRelayer .Command ], register : ActorRef , relayId : UUID , paymentHash : ByteVector32 , outgoingPaymentFactory : OutgoingPaymentFactory ): Behavior [Command ] =
79
+ Behaviors .setup { context =>
80
+ Behaviors .withMdc(Logs .mdc(
81
+ category_opt = Some (Logs .LogCategory .PAYMENT ),
82
+ parentPaymentId_opt = Some (relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment
83
+ paymentHash_opt = Some (paymentHash))) {
84
+ new NodeRelay (nodeParams, parent, register, relayId, paymentHash, context, outgoingPaymentFactory)()
85
+ }
86
+ }
87
+
86
88
def validateRelay (nodeParams : NodeParams , upstream : Upstream .Trampoline , payloadOut : Onion .NodeRelayPayload ): Option [FailureMessage ] = {
87
89
val fee = nodeFee(nodeParams.feeBase, nodeParams.feeProportionalMillionth, payloadOut.amountToForward)
88
90
if (upstream.amountIn - payloadOut.amountToForward < fee) {
@@ -139,12 +141,11 @@ object NodeRelay {
139
141
*/
140
142
class NodeRelay private (nodeParams : NodeParams ,
141
143
parent : akka.actor.typed.ActorRef [NodeRelayer .Command ],
142
- router : ActorRef ,
143
144
register : ActorRef ,
144
145
relayId : UUID ,
145
146
paymentHash : ByteVector32 ,
146
147
context : ActorContext [NodeRelay .Command ],
147
- fsmFactory : FsmFactory ) {
148
+ outgoingPaymentFactory : NodeRelay . OutgoingPaymentFactory ) {
148
149
149
150
import NodeRelay ._
150
151
@@ -285,20 +286,20 @@ class NodeRelay private(nodeParams: NodeParams,
285
286
case Some (paymentSecret) if Features (features).hasFeature(Features .BasicMultiPartPayment ) =>
286
287
context.log.debug(" sending the payment to non-trampoline recipient using MPP" )
287
288
val payment = SendMultiPartPayment (payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routingHints, Some (routeParams))
288
- val payFSM = fsmFactory .spawnOutgoingPayFSM(context, nodeParams, router, register , paymentCfg, multiPart = true )
289
+ val payFSM = outgoingPaymentFactory .spawnOutgoingPayFSM(context, paymentCfg, multiPart = true )
289
290
payFSM ! payment
290
291
payFSM
291
292
case _ =>
292
293
context.log.debug(" sending the payment to non-trampoline recipient without MPP" )
293
294
val finalPayload = Onion .createSinglePartPayload(payloadOut.amountToForward, payloadOut.outgoingCltv, payloadOut.paymentSecret)
294
295
val payment = SendPayment (payFsmAdapters, payloadOut.outgoingNodeId, finalPayload, nodeParams.maxPaymentAttempts, routingHints, Some (routeParams))
295
- val payFSM = fsmFactory .spawnOutgoingPayFSM(context, nodeParams, router, register , paymentCfg, multiPart = false )
296
+ val payFSM = outgoingPaymentFactory .spawnOutgoingPayFSM(context, paymentCfg, multiPart = false )
296
297
payFSM ! payment
297
298
payFSM
298
299
}
299
300
case None =>
300
301
context.log.debug(" sending the payment to the next trampoline node" )
301
- val payFSM = fsmFactory .spawnOutgoingPayFSM(context, nodeParams, router, register , paymentCfg, multiPart = true )
302
+ val payFSM = outgoingPaymentFactory .spawnOutgoingPayFSM(context, paymentCfg, multiPart = true )
302
303
val paymentSecret = randomBytes32 // we generate a new secret to protect against probing attacks
303
304
val payment = SendMultiPartPayment (payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routeParams = Some (routeParams), additionalTlvs = Seq (OnionTlv .TrampolineOnion (packetOut)))
304
305
payFSM ! payment
0 commit comments