Skip to content

Commit

Permalink
Change type architecture for onion per-hop payload.
Browse files Browse the repository at this point in the history
Explicitly expand the matrix of possible types (relay/final, legacy/tlv).
  • Loading branch information
t-bast committed Sep 4, 2019
1 parent 589690b commit 9c81ef5
Show file tree
Hide file tree
Showing 12 changed files with 374 additions and 359 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ import akka.actor.{Actor, ActorLogging, ActorRef, Props}
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.channel.Channel
import fr.acinq.eclair.payment.PaymentLifecycle.{LegacyPayload, SendPayment, SendPaymentToRoute}
import fr.acinq.eclair.payment.PaymentLifecycle.{SendPayment, SendPaymentToRoute}
import fr.acinq.eclair.payment.PaymentRequest.ExtraHop
import fr.acinq.eclair.router.RouteParams
import fr.acinq.eclair.wire.Onion.FinalLegacyPayload
import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi, NodeParams}

/**
Expand All @@ -39,8 +40,8 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, register: Actor
val finalExpiry = (p.finalExpiryDelta + 1).toCltvExpiry
val payFsm = context.actorOf(PaymentLifecycle.props(nodeParams, paymentId, router, register))
p.predefinedRoute match {
case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, LegacyPayload(p.amount, finalExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams)
case hops => payFsm forward SendPaymentToRoute(p.paymentHash, hops, LegacyPayload(p.amount, finalExpiry))
case Nil => payFsm forward SendPayment(p.paymentHash, p.targetNodeId, FinalLegacyPayload(p.amount, finalExpiry), p.maxAttempts, p.assistedRoutes, p.routeParams)
case hops => payFsm forward SendPaymentToRoute(p.paymentHash, hops, FinalLegacyPayload(p.amount, finalExpiry))
}
sender ! paymentId
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import fr.acinq.eclair.db.{OutgoingPayment, OutgoingPaymentStatus}
import fr.acinq.eclair.payment.PaymentLifecycle._
import fr.acinq.eclair.payment.PaymentRequest.ExtraHop
import fr.acinq.eclair.router._
import fr.acinq.eclair.wire.OnionPerHopPayload._
import fr.acinq.eclair.wire.Onion._
import fr.acinq.eclair.wire._
import scodec.Attempt
import scodec.bits.ByteVector
Expand All @@ -47,22 +47,22 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis

when(WAITING_FOR_REQUEST) {
case Event(c: SendPaymentToRoute, WaitingForRequest) =>
val send = SendPayment(c.paymentHash, c.hops.last, c.paymentOptions, maxAttempts = 1)
paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.paymentOptions.finalAmount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING))
val send = SendPayment(c.paymentHash, c.hops.last, c.finalPayload, maxAttempts = 1)
paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.finalPayload.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING))
router ! FinalizeRoute(c.hops)
goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, send, failures = Nil)

case Event(c: SendPayment, WaitingForRequest) =>
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, routeParams = c.routeParams)
paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.paymentOptions.finalAmount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING))
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, routeParams = c.routeParams)
paymentsDb.addOutgoingPayment(OutgoingPayment(id, c.paymentHash, None, c.finalPayload.amount, Platform.currentTime, None, OutgoingPaymentStatus.PENDING))
goto(WAITING_FOR_ROUTE) using WaitingForRoute(sender, c, failures = Nil)
}

when(WAITING_FOR_ROUTE) {
case Event(RouteResponse(hops, ignoreNodes, ignoreChannels), WaitingForRoute(s, c, failures)) =>
log.info(s"route found: attempt=${failures.size + 1}/${c.maxAttempts} route=${hops.map(_.nextNodeId).mkString("->")} channels=${hops.map(_.lastUpdate.shortChannelId).mkString("->")}")
val firstHop = hops.head
val (cmd, sharedSecrets) = buildCommand(id, c.paymentHash, hops, c.paymentOptions)
val (cmd, sharedSecrets) = buildCommand(id, c.paymentHash, hops, c.finalPayload)
register ! Register.ForwardShortId(firstHop.lastUpdate.shortChannelId, cmd)
goto(WAITING_FOR_PAYMENT_COMPLETE) using WaitingForComplete(s, c, cmd, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops)

Expand All @@ -78,7 +78,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis
case Event(fulfill: UpdateFulfillHtlc, WaitingForComplete(s, c, cmd, _, _, _, _, hops)) =>
paymentsDb.updateOutgoingPayment(id, OutgoingPaymentStatus.SUCCEEDED, preimage = Some(fulfill.paymentPreimage))
reply(s, PaymentSucceeded(id, cmd.amount, c.paymentHash, fulfill.paymentPreimage, hops))
context.system.eventStream.publish(PaymentSent(id, c.paymentOptions.finalAmount, cmd.amount - c.paymentOptions.finalAmount, cmd.paymentHash, fulfill.paymentPreimage, fulfill.channelId))
context.system.eventStream.publish(PaymentSent(id, c.finalPayload.amount, cmd.amount - c.finalPayload.amount, cmd.paymentHash, fulfill.paymentPreimage, fulfill.channelId))
stop(FSM.Normal)

case Event(fail: UpdateFailHtlc, WaitingForComplete(s, c, _, failures, sharedSecrets, ignoreNodes, ignoreChannels, hops)) =>
Expand Down Expand Up @@ -108,12 +108,12 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis
// in that case we don't know which node is sending garbage, let's try to blacklist all nodes except the one we are directly connected to and the destination node
val blacklist = hops.map(_.nextNodeId).drop(1).dropRight(1)
log.warning(s"blacklisting intermediate nodes=${blacklist.mkString(",")}")
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes ++ blacklist, ignoreChannels, c.routeParams)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes ++ blacklist, ignoreChannels, c.routeParams)
goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ UnreadableRemoteFailure(hops))
case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Node)) =>
log.info(s"received 'Node' type error message from nodeId=$nodeId, trying to route around it (failure=$failureMessage)")
// let's try to route around this node
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams)
goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e))
case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Update)) =>
log.info(s"received 'Update' type error message from nodeId=$nodeId, retrying payment (failure=$failureMessage)")
Expand Down Expand Up @@ -141,18 +141,18 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis
// in any case, we forward the update to the router
router ! failureMessage.update
// let's try again, router will have updated its state
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes, ignoreChannels, c.routeParams)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes, ignoreChannels, c.routeParams)
} else {
// this node is fishy, it gave us a bad sig!! let's filter it out
log.warning(s"got bad signature from node=$nodeId update=${failureMessage.update}")
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes + nodeId, ignoreChannels, c.routeParams)
}
goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e))
case Success(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) =>
log.info(s"received an error message from nodeId=$nodeId, trying to use a different channel (failure=$failureMessage)")
// let's try again without the channel outgoing from nodeId
val faultyChannel = hops.find(_.nodeId == nodeId).map(hop => ChannelDesc(hop.lastUpdate.shortChannelId, hop.nodeId, hop.nextNodeId))
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes, ignoreChannels ++ faultyChannel.toSet, c.routeParams)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes, ignoreChannels ++ faultyChannel.toSet, c.routeParams)
goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ RemoteFailure(hops, e))
}

Expand All @@ -172,7 +172,7 @@ class PaymentLifecycle(nodeParams: NodeParams, id: UUID, router: ActorRef, regis
} else {
log.info(s"received an error message from local, trying to use a different channel (failure=${t.getMessage})")
val faultyChannel = ChannelDesc(hops.head.lastUpdate.shortChannelId, hops.head.nodeId, hops.head.nextNodeId)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.paymentOptions.finalAmount, c.assistedRoutes, ignoreNodes, ignoreChannels + faultyChannel, c.routeParams)
router ! RouteRequest(nodeParams.nodeId, c.targetNodeId, c.finalPayload.amount, c.assistedRoutes, ignoreNodes, ignoreChannels + faultyChannel, c.routeParams)
goto(WAITING_FOR_ROUTE) using WaitingForRoute(s, c, failures :+ LocalFailure(t))
}

Expand All @@ -196,14 +196,14 @@ object PaymentLifecycle {

// @formatter:off
case class ReceivePayment(amount_opt: Option[MilliSatoshi], description: String, expirySeconds_opt: Option[Long] = None, extraHops: List[List[ExtraHop]] = Nil, fallbackAddress: Option[String] = None, paymentPreimage: Option[ByteVector32] = None)
case class SendPaymentToRoute(paymentHash: ByteVector32, hops: Seq[PublicKey], paymentOptions: PaymentOptions)
case class SendPaymentToRoute(paymentHash: ByteVector32, hops: Seq[PublicKey], finalPayload: FinalPayload)
case class SendPayment(paymentHash: ByteVector32,
targetNodeId: PublicKey,
paymentOptions: PaymentOptions,
finalPayload: FinalPayload,
maxAttempts: Int,
assistedRoutes: Seq[Seq[ExtraHop]] = Nil,
routeParams: Option[RouteParams] = None) {
require(paymentOptions.finalAmount > 0.msat, s"amount must be > 0")
require(finalPayload.amount > 0.msat, s"amount must be > 0")
}

sealed trait PaymentResult
Expand All @@ -214,18 +214,6 @@ object PaymentLifecycle {
case class UnreadableRemoteFailure(route: Seq[Hop]) extends PaymentFailure
case class PaymentFailed(id: UUID, paymentHash: ByteVector32, failures: Seq[PaymentFailure]) extends PaymentResult

/**
* Options to help build the final payload of the payment route.
*/
sealed trait PaymentOptions {
// The final htlc amount in millisatoshis.
val finalAmount: MilliSatoshi
// The final htlc expiry in number of blocks.
val finalExpiry: CltvExpiry
}
case class LegacyPayload(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry) extends PaymentOptions
case class TlvPayload(finalAmount: MilliSatoshi, finalExpiry: CltvExpiry, records: Seq[OnionTlv] = Nil) extends PaymentOptions

sealed trait Data
case object WaitingForRequest extends Data
case class WaitingForRoute(sender: ActorRef, c: SendPayment, failures: Seq[PaymentFailure]) extends Data
Expand All @@ -237,11 +225,14 @@ object PaymentLifecycle {
case object WAITING_FOR_PAYMENT_COMPLETE extends State
// @formatter:on

def buildOnion(nodes: Seq[PublicKey], payloads: Seq[OnionPerHopPayload], associatedData: ByteVector32): Sphinx.PacketAndSecrets = {
def buildOnion(nodes: Seq[PublicKey], payloads: Seq[PerHopPayload], associatedData: ByteVector32): Sphinx.PacketAndSecrets = {
require(nodes.size == payloads.size)
val sessionKey = randomKey
val payloadsBin: Seq[ByteVector] = payloads
.map(OnionCodecs.perHopPayloadCodec.encode)
.map({
case p: FinalPayload => OnionCodecs.finalPerHopPayloadCodec.encode(p)
case p: RelayPayload => OnionCodecs.relayPerHopPayloadCodec.encode(p)
})
.map {
case Attempt.Successful(bitVector) => bitVector.toByteVector
case Attempt.Failure(cause) => throw new RuntimeException(s"serialization error: $cause")
Expand All @@ -252,29 +243,25 @@ object PaymentLifecycle {
/**
* Build the onion payloads for each hop.
*
* @param hops the hops as computed by the router + extra routes from payment request
* @param opts options to help build each hop's payload (final amount, expiry, additional tlv records, etc)
* @param hops the hops as computed by the router + extra routes from payment request
* @param finalPayload payload data for the final node (amount, expiry, additional tlv records, etc)
* @return a (firstAmount, firstExpiry, payloads) tuple where:
* - firstAmount is the amount for the first htlc in the route
* - firstExpiry is the cltv expiry for the first htlc in the route
* - a sequence of payloads that will be used to build the onion
*/
def buildPayloads(hops: Seq[Hop], opts: PaymentOptions): (MilliSatoshi, CltvExpiry, Seq[OnionPerHopPayload]) = {
val finalPayload: Seq[OnionPerHopPayload] = opts match {
case p: LegacyPayload => OnionForwardInfo(ShortChannelId(0L), p.finalAmount, p.finalExpiry) :: Nil
case p: TlvPayload => TlvStream[OnionTlv](OnionTlv.AmountToForward(p.finalAmount) +: OnionTlv.OutgoingCltv(p.finalExpiry) +: p.records) :: Nil
}
hops.reverse.foldLeft((opts.finalAmount, opts.finalExpiry, finalPayload)) {
def buildPayloads(hops: Seq[Hop], finalPayload: FinalPayload): (MilliSatoshi, CltvExpiry, Seq[PerHopPayload]) = {
hops.reverse.foldLeft((finalPayload.amount, finalPayload.expiry, Seq[PerHopPayload](finalPayload))) {
case ((amount, expiry, payloads), hop) =>
val nextFee = nodeFee(hop.lastUpdate.feeBaseMsat, hop.lastUpdate.feeProportionalMillionths, amount)
// Since we don't have any scenario where we add tlv data for intermediate hops, we use legacy payloads.
val payload: OnionPerHopPayload = OnionForwardInfo(hop.lastUpdate.shortChannelId, amount, expiry)
val payload = RelayLegacyPayload(hop.lastUpdate.shortChannelId, amount, expiry)
(amount + nextFee, expiry + hop.lastUpdate.cltvExpiryDelta, payload +: payloads)
}
}

def buildCommand(id: UUID, paymentHash: ByteVector32, hops: Seq[Hop], opts: PaymentOptions): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = {
val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), opts)
def buildCommand(id: UUID, paymentHash: ByteVector32, hops: Seq[Hop], finalPayload: FinalPayload): (CMD_ADD_HTLC, Seq[(ByteVector32, PublicKey)]) = {
val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), finalPayload)
val nodes = hops.map(_.nextNodeId)
// BOLT 2 requires that associatedData == paymentHash
val onion = buildOnion(nodes, payloads, paymentHash)
Expand Down
Loading

0 comments on commit 9c81ef5

Please sign in to comment.