Skip to content

Commit c6a76af

Browse files
authored
Introduce actor factories (#1744)
This removes unnecessary fields and allows more flexibility in tests.
1 parent e5429eb commit c6a76af

12 files changed

+195
-136
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala

+9-5
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import fr.acinq.eclair.channel.Register
3737
import fr.acinq.eclair.crypto.keymanager.{LocalChannelKeyManager, LocalNodeKeyManager}
3838
import fr.acinq.eclair.db.Databases.FileBackup
3939
import fr.acinq.eclair.db.{Databases, DbEventHandler, FileBackupHandler}
40-
import fr.acinq.eclair.io.{ClientSpawner, Server, Switchboard}
40+
import fr.acinq.eclair.io.{ClientSpawner, Peer, Server, Switchboard}
4141
import fr.acinq.eclair.payment.receive.PaymentHandler
4242
import fr.acinq.eclair.payment.relay.Relayer
4343
import fr.acinq.eclair.payment.send.{Autoprobe, PaymentInitiator}
@@ -290,8 +290,8 @@ class Setup(datadir: File,
290290
new ElectrumEclairWallet(electrumWallet, nodeParams.chainHash)
291291
}
292292
_ = wallet.getReceiveAddress.map(address => logger.info(s"initial wallet address=$address"))
293-
// do not change the name of this actor. it is used in the configuration to specify a custom bounded mailbox
294293

294+
// do not change the name of this actor. it is used in the configuration to specify a custom bounded mailbox
295295
backupHandler = if (config.getBoolean("enable-db-backup")) {
296296
nodeParams.db match {
297297
case fileBackup: FileBackup => system.actorOf(SimpleSupervisor.props(
@@ -314,10 +314,14 @@ class Setup(datadir: File,
314314
// Before initializing the switchboard (which re-connects us to the network) and the user-facing parts of the system,
315315
// we want to make sure the handler for post-restart broken HTLCs has finished initializing.
316316
_ <- postRestartCleanUpInitialized.future
317-
switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, watcher, relayer, wallet), "switchboard", SupervisorStrategy.Resume))
317+
318+
channelFactory = Peer.SimpleChannelFactory(nodeParams, watcher, relayer, wallet)
319+
peerFactory = Switchboard.SimplePeerFactory(nodeParams, wallet, channelFactory)
320+
321+
switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, peerFactory), "switchboard", SupervisorStrategy.Resume))
318322
clientSpawner = system.actorOf(SimpleSupervisor.props(ClientSpawner.props(nodeParams.keyPair, nodeParams.socksProxy_opt, nodeParams.peerConnectionConf, switchboard, router), "client-spawner", SupervisorStrategy.Restart))
319323
server = system.actorOf(SimpleSupervisor.props(Server.props(nodeParams.keyPair, nodeParams.peerConnectionConf, switchboard, router, serverBindingAddress, Some(tcpBound)), "server", SupervisorStrategy.Restart))
320-
paymentInitiator = system.actorOf(SimpleSupervisor.props(PaymentInitiator.props(nodeParams, router, register), "payment-initiator", SupervisorStrategy.Restart))
324+
paymentInitiator = system.actorOf(SimpleSupervisor.props(PaymentInitiator.props(nodeParams, PaymentInitiator.SimplePaymentFactory(nodeParams, router, register)), "payment-initiator", SupervisorStrategy.Restart))
321325
_ = for (i <- 0 until config.getInt("autoprobe-count")) yield system.actorOf(SimpleSupervisor.props(Autoprobe.props(nodeParams, router, paymentInitiator), s"payment-autoprobe-$i", SupervisorStrategy.Restart))
322326

323327
kit = Kit(
@@ -381,11 +385,11 @@ class Setup(datadir: File,
381385

382386
}
383387

388+
// @formatter:off
384389
object Setup {
385390
final case class Seeds(nodeSeed: ByteVector, channelSeed: ByteVector)
386391
}
387392

388-
// @formatter:off
389393
sealed trait Bitcoin
390394
case class Bitcoind(bitcoinClient: BasicBitcoinJsonRPCClient) extends Bitcoin
391395
case class Electrum(electrumClient: ActorRef) extends Bitcoin

eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala

+16-7
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
package fr.acinq.eclair.io
1818

19-
import akka.actor.{Actor, ActorRef, ExtendedActorSystem, FSM, OneForOneStrategy, PossiblyHarmful, Props, Status, SupervisorStrategy, Terminated}
19+
import akka.actor.{Actor, ActorContext, ActorRef, ExtendedActorSystem, FSM, OneForOneStrategy, PossiblyHarmful, Props, Status, SupervisorStrategy, Terminated}
2020
import akka.event.Logging.MDC
2121
import akka.event.{BusLogging, DiagnosticLoggingAdapter}
2222
import akka.util.Timeout
@@ -48,7 +48,7 @@ import java.net.InetSocketAddress
4848
*
4949
* Created by PM on 26/08/2016.
5050
*/
51-
class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] {
51+
class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: EclairWallet, channelFactory: Peer.ChannelFactory) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] {
5252

5353
import Peer._
5454

@@ -57,7 +57,7 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRe
5757
when(INSTANTIATING) {
5858
case Event(Init(storedChannels), _) =>
5959
val channels = storedChannels.map { state =>
60-
val channel = spawnChannel(nodeParams, origin_opt = None)
60+
val channel = spawnChannel(origin_opt = None)
6161
channel ! INPUT_RESTORED(state)
6262
FinalChannelId(state.channelId) -> channel
6363
}.toMap
@@ -294,12 +294,12 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRe
294294
(Helpers.getFinalScriptPubKey(wallet, nodeParams.chainHash), None)
295295
}
296296
val localParams = makeChannelParams(nodeParams, features, finalScript, walletStaticPaymentBasepoint, funder, fundingAmount)
297-
val channel = spawnChannel(nodeParams, origin_opt)
297+
val channel = spawnChannel(origin_opt)
298298
(channel, localParams)
299299
}
300300

301-
def spawnChannel(nodeParams: NodeParams, origin_opt: Option[ActorRef]): ActorRef = {
302-
val channel = context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, origin_opt))
301+
def spawnChannel(origin_opt: Option[ActorRef]): ActorRef = {
302+
val channel = channelFactory.spawn(context, remoteNodeId, origin_opt)
303303
context watch channel
304304
channel
305305
}
@@ -353,7 +353,16 @@ object Peer {
353353
val UNKNOWN_CHANNEL_MESSAGE: ByteVector = ByteVector.view("unknown channel".getBytes())
354354
// @formatter:on
355355

356-
def props(nodeParams: NodeParams, remoteNodeId: PublicKey, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet): Props = Props(new Peer(nodeParams, remoteNodeId, watcher, relayer: ActorRef, wallet))
356+
trait ChannelFactory {
357+
def spawn(context: ActorContext, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): ActorRef
358+
}
359+
360+
case class SimpleChannelFactory(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) extends ChannelFactory {
361+
override def spawn(context: ActorContext, remoteNodeId: PublicKey, origin_opt: Option[ActorRef]): ActorRef =
362+
context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, origin_opt))
363+
}
364+
365+
def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: EclairWallet, channelFactory: ChannelFactory): Props = Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory))
357366

358367
// @formatter:off
359368

eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala

+13-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
package fr.acinq.eclair.io
1818

19-
import akka.actor.{Actor, ActorLogging, ActorRef, OneForOneStrategy, Props, Status, SupervisorStrategy}
19+
import akka.actor.{Actor, ActorContext, ActorLogging, ActorRef, OneForOneStrategy, Props, Status, SupervisorStrategy}
2020
import fr.acinq.bitcoin.Crypto.PublicKey
2121
import fr.acinq.eclair.NodeParams
2222
import fr.acinq.eclair.blockchain.EclairWallet
@@ -29,7 +29,7 @@ import fr.acinq.eclair.router.Router.RouterConf
2929
* Ties network connections to peers.
3030
* Created by PM on 14/02/2017.
3131
*/
32-
class Switchboard(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) extends Actor with ActorLogging {
32+
class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory) extends Actor with ActorLogging {
3333

3434
import Switchboard._
3535

@@ -103,7 +103,7 @@ class Switchboard(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef,
103103
*/
104104
def getPeer(remoteNodeId: PublicKey): Option[ActorRef] = context.child(peerActorName(remoteNodeId))
105105

106-
def createPeer(remoteNodeId: PublicKey): ActorRef = context.actorOf(Peer.props(nodeParams, remoteNodeId, watcher, relayer, wallet), name = peerActorName(remoteNodeId))
106+
def createPeer(remoteNodeId: PublicKey): ActorRef = peerFactory.spawn(context, remoteNodeId)
107107

108108
def createOrGetPeer(remoteNodeId: PublicKey, offlineChannels: Set[HasCommitments]): ActorRef = {
109109
getPeer(remoteNodeId) match {
@@ -124,7 +124,16 @@ class Switchboard(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef,
124124

125125
object Switchboard {
126126

127-
def props(nodeParams: NodeParams, watcher: ActorRef, relayer: ActorRef, wallet: EclairWallet) = Props(new Switchboard(nodeParams, watcher, relayer, wallet))
127+
trait PeerFactory {
128+
def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef
129+
}
130+
131+
case class SimplePeerFactory(nodeParams: NodeParams, wallet: EclairWallet, channelFactory: Peer.ChannelFactory) extends PeerFactory {
132+
override def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef =
133+
context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory), name = peerActorName(remoteNodeId))
134+
}
135+
136+
def props(nodeParams: NodeParams, peerFactory: PeerFactory) = Props(new Switchboard(nodeParams, peerFactory))
128137

129138
def peerActorName(remoteNodeId: PublicKey): String = s"peer-$remoteNodeId"
130139

eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala

+23-22
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,10 @@ import fr.acinq.eclair.payment.OutgoingPacket.Upstream
2929
import fr.acinq.eclair.payment._
3030
import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM
3131
import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart
32-
import fr.acinq.eclair.payment.relay.NodeRelay.FsmFactory
3332
import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.{PreimageReceived, SendMultiPartPayment}
3433
import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig
3534
import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment
36-
import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentLifecycle}
35+
import fr.acinq.eclair.payment.send.{MultiPartPaymentLifecycle, PaymentInitiator, PaymentLifecycle}
3736
import fr.acinq.eclair.router.Router.RouteParams
3837
import fr.acinq.eclair.router.{BalanceTooLow, RouteCalculation, RouteNotFound}
3938
import fr.acinq.eclair.wire.protocol._
@@ -60,29 +59,32 @@ object NodeRelay {
6059
private case class WrappedPaymentFailed(paymentFailed: PaymentFailed) extends Command
6160
// @formatter:on
6261

63-
def apply(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], router: ActorRef, register: ActorRef, relayId: UUID, paymentHash: ByteVector32, fsmFactory: FsmFactory = new FsmFactory): Behavior[Command] =
64-
Behaviors.setup { context =>
65-
Behaviors.withMdc(Logs.mdc(
66-
category_opt = Some(Logs.LogCategory.PAYMENT),
67-
parentPaymentId_opt = Some(relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment
68-
paymentHash_opt = Some(paymentHash))) {
69-
new NodeRelay(nodeParams, parent, router, register, relayId, paymentHash, context, fsmFactory)()
70-
}
71-
}
62+
trait OutgoingPaymentFactory {
63+
def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], cfg: SendPaymentConfig, multiPart: Boolean): ActorRef
64+
}
7265

73-
/**
74-
* This is supposed to be overridden in tests
75-
*/
76-
class FsmFactory {
77-
def spawnOutgoingPayFSM(context: ActorContext[NodeRelay.Command], nodeParams: NodeParams, router: ActorRef, register: ActorRef, cfg: SendPaymentConfig, multiPart: Boolean): ActorRef = {
66+
case class SimpleOutgoingPaymentFactory(nodeParams: NodeParams, router: ActorRef, register: ActorRef) extends OutgoingPaymentFactory {
67+
val paymentFactory = PaymentInitiator.SimplePaymentFactory(nodeParams, router, register)
68+
69+
override def spawnOutgoingPayFSM(context: ActorContext[Command], cfg: SendPaymentConfig, multiPart: Boolean): ActorRef = {
7870
if (multiPart) {
79-
context.toClassic.actorOf(MultiPartPaymentLifecycle.props(nodeParams, cfg, router, register))
71+
context.toClassic.actorOf(MultiPartPaymentLifecycle.props(nodeParams, cfg, router, paymentFactory))
8072
} else {
8173
context.toClassic.actorOf(PaymentLifecycle.props(nodeParams, cfg, router, register))
8274
}
8375
}
8476
}
8577

78+
def apply(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], register: ActorRef, relayId: UUID, paymentHash: ByteVector32, outgoingPaymentFactory: OutgoingPaymentFactory): Behavior[Command] =
79+
Behaviors.setup { context =>
80+
Behaviors.withMdc(Logs.mdc(
81+
category_opt = Some(Logs.LogCategory.PAYMENT),
82+
parentPaymentId_opt = Some(relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment
83+
paymentHash_opt = Some(paymentHash))) {
84+
new NodeRelay(nodeParams, parent, register, relayId, paymentHash, context, outgoingPaymentFactory)()
85+
}
86+
}
87+
8688
def validateRelay(nodeParams: NodeParams, upstream: Upstream.Trampoline, payloadOut: Onion.NodeRelayPayload): Option[FailureMessage] = {
8789
val fee = nodeFee(nodeParams.feeBase, nodeParams.feeProportionalMillionth, payloadOut.amountToForward)
8890
if (upstream.amountIn - payloadOut.amountToForward < fee) {
@@ -139,12 +141,11 @@ object NodeRelay {
139141
*/
140142
class NodeRelay private(nodeParams: NodeParams,
141143
parent: akka.actor.typed.ActorRef[NodeRelayer.Command],
142-
router: ActorRef,
143144
register: ActorRef,
144145
relayId: UUID,
145146
paymentHash: ByteVector32,
146147
context: ActorContext[NodeRelay.Command],
147-
fsmFactory: FsmFactory) {
148+
outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory) {
148149

149150
import NodeRelay._
150151

@@ -285,20 +286,20 @@ class NodeRelay private(nodeParams: NodeParams,
285286
case Some(paymentSecret) if Features(features).hasFeature(Features.BasicMultiPartPayment) =>
286287
context.log.debug("sending the payment to non-trampoline recipient using MPP")
287288
val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routingHints, Some(routeParams))
288-
val payFSM = fsmFactory.spawnOutgoingPayFSM(context, nodeParams, router, register, paymentCfg, multiPart = true)
289+
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true)
289290
payFSM ! payment
290291
payFSM
291292
case _ =>
292293
context.log.debug("sending the payment to non-trampoline recipient without MPP")
293294
val finalPayload = Onion.createSinglePartPayload(payloadOut.amountToForward, payloadOut.outgoingCltv, payloadOut.paymentSecret)
294295
val payment = SendPayment(payFsmAdapters, payloadOut.outgoingNodeId, finalPayload, nodeParams.maxPaymentAttempts, routingHints, Some(routeParams))
295-
val payFSM = fsmFactory.spawnOutgoingPayFSM(context, nodeParams, router, register, paymentCfg, multiPart = false)
296+
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = false)
296297
payFSM ! payment
297298
payFSM
298299
}
299300
case None =>
300301
context.log.debug("sending the payment to the next trampoline node")
301-
val payFSM = fsmFactory.spawnOutgoingPayFSM(context, nodeParams, router, register, paymentCfg, multiPart = true)
302+
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true)
302303
val paymentSecret = randomBytes32 // we generate a new secret to protect against probing attacks
303304
val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routeParams = Some(routeParams), additionalTlvs = Seq(OnionTlv.TrampolineOnion(packetOut)))
304305
payFSM ! payment

eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ object NodeRelayer {
6666
case None =>
6767
val relayId = UUID.randomUUID()
6868
context.log.debug(s"spawning a new handler with relayId=$relayId")
69-
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, router, register, relayId, paymentHash), relayId.toString)
69+
val outgoingPaymentFactory = NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register)
70+
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, paymentHash, outgoingPaymentFactory), relayId.toString)
7071
context.log.debug("forwarding incoming htlc to new handler")
7172
handler ! NodeRelay.Relay(nodeRelayPacket)
7273
apply(nodeParams, router, register, children + (paymentHash -> handler))

eclair-core/src/main/scala/fr/acinq/eclair/payment/send/MultiPartPaymentLifecycle.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import java.util.concurrent.TimeUnit
4444
* Sender for a multi-part payment (see https://github.com/lightningnetwork/lightning-rfc/blob/master/04-onion-routing.md#basic-multi-part-payments).
4545
* The payment will be split into multiple sub-payments that will be sent in parallel.
4646
*/
47-
class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, register: ActorRef) extends FSMDiagnosticActorLogging[MultiPartPaymentLifecycle.State, MultiPartPaymentLifecycle.Data] {
47+
class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, paymentFactory: PaymentInitiator.PaymentFactory) extends FSMDiagnosticActorLogging[MultiPartPaymentLifecycle.State, MultiPartPaymentLifecycle.Data] {
4848

4949
import MultiPartPaymentLifecycle._
5050

@@ -202,13 +202,13 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
202202
case Event(_: Status.Failure, _) => stay
203203
}
204204

205-
def spawnChildPaymentFsm(childId: UUID): ActorRef = {
205+
private def spawnChildPaymentFsm(childId: UUID): ActorRef = {
206206
val upstream = cfg.upstream match {
207207
case Upstream.Local(_) => Upstream.Local(childId)
208208
case _ => cfg.upstream
209209
}
210210
val childCfg = cfg.copy(id = childId, publishEvent = false, upstream = upstream)
211-
context.actorOf(PaymentLifecycle.props(nodeParams, childCfg, router, register))
211+
paymentFactory.spawnOutgoingPayment(context, childCfg)
212212
}
213213

214214
private def gotoAbortedOrStop(d: PaymentAborted): State = {
@@ -265,7 +265,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
265265

266266
object MultiPartPaymentLifecycle {
267267

268-
def props(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, register: ActorRef) = Props(new MultiPartPaymentLifecycle(nodeParams, cfg, router, register))
268+
def props(nodeParams: NodeParams, cfg: SendPaymentConfig, router: ActorRef, paymentFactory: PaymentInitiator.PaymentFactory) = Props(new MultiPartPaymentLifecycle(nodeParams, cfg, router, paymentFactory))
269269

270270
/**
271271
* Send a payment to a given node. The payment may be split into multiple child payments, for which a path-finding

0 commit comments

Comments
 (0)