Skip to content

Commit

Permalink
Add support for arbitrary length onion errors
Browse files Browse the repository at this point in the history
The specification recommends using a length of 256 for onion errors, but
it doesn't say that we should reject errors that use a different length.

We may want to start creating errors with a bigger length than 256 if we
need to transmit more data to the sender. In order to prepare for this,
we keep creating 256-bytes onion errors, but allow receiving errors of
arbitrary length.
  • Loading branch information
t-bast committed Sep 30, 2022
1 parent 8ed94c5 commit d5adba0
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 40 deletions.
52 changes: 17 additions & 35 deletions src/commonMain/kotlin/fr/acinq/lightning/crypto/sphinx/Sphinx.kt
Original file line number Diff line number Diff line change
Expand Up @@ -292,37 +292,33 @@ data class DecryptedFailurePacket(val originNode: PublicKey, val failureMessage:
* +----------------+----------------------------------+-----------------+----------------------+-----+
* | HMAC(32 bytes) | failure message length (2 bytes) | failure message | pad length (2 bytes) | pad |
* +----------------+----------------------------------+-----------------+----------------------+-----+
* with failure message length + pad length = 256
* Bolt 4: SHOULD set pad such that the failure_len plus pad_len is equal to 256
*/
object FailurePacket {

private const val MaxPayloadLength = 256
private const val PacketLength = Sphinx.MacLength + MaxPayloadLength + 2 + 2
private const val RecommendedPayloadLength = 256

fun encode(failure: FailureMessage, macKey: ByteVector32): ByteArray {
fun encode(failure: FailureMessage, macKey: ByteVector32, payloadLength: Int = RecommendedPayloadLength): ByteArray {
val out = ByteArrayOutput()
val failureMessageBin = FailureMessage.encode(failure)
require(failureMessageBin.size <= MaxPayloadLength) { "encoded failure message overflows onion" }
require(failureMessageBin.size <= payloadLength) { "encoded failure message overflows onion" }
LightningCodecs.writeU16(failureMessageBin.size, out)
LightningCodecs.writeBytes(failureMessageBin, out)
val padLen = MaxPayloadLength - failureMessageBin.size
val padLen = payloadLength - failureMessageBin.size
LightningCodecs.writeU16(padLen, out)
LightningCodecs.writeBytes(ByteArray(padLen), out)
val packet = out.toByteArray()
return Sphinx.mac(macKey.toByteArray(), packet).toByteArray() + packet
}

fun decode(input: ByteArray, macKey: ByteVector32): Try<FailureMessage> {
if (input.size != PacketLength) {
return Try.Failure(IllegalArgumentException("invalid error packet length: ${Hex.encode(input)}"))
}
val mac = input.take(32).toByteArray().toByteVector32()
val payload = input.drop(32).toByteArray()
if (Sphinx.mac(macKey.toByteArray(), payload) != mac) {
val packet = input.drop(32).toByteArray()
if (Sphinx.mac(macKey.toByteArray(), packet) != mac) {
return Try.Failure(IllegalArgumentException("invalid error packet mac: ${Hex.encode(input)}"))
}
val stream = ByteArrayInput(payload)
return runTrying { FailureMessage.decode(LightningCodecs.bytes(stream, LightningCodecs.u16(stream))) }
val payload = ByteArrayInput(packet)
return runTrying { FailureMessage.decode(LightningCodecs.bytes(payload, LightningCodecs.u16(payload))) }
}

/**
Expand All @@ -348,18 +344,10 @@ object FailurePacket {
* @param sharedSecret destination node's shared secret.
* @return an encrypted failure packet that can be sent to the destination node.
*/
fun wrap(packet: ByteArray, sharedSecret: ByteVector32): ByteArray = tryWrap(packet, sharedSecret).get()

private fun tryWrap(packet: ByteArray, sharedSecret: ByteVector32): Try<ByteArray> {
if (packet.size != PacketLength) {
val ex = IllegalArgumentException("invalid error packet length ${packet.size}, must be $PacketLength (malicious or buggy downstream node)")
return Try.Failure(ex)
}
fun wrap(packet: ByteArray, sharedSecret: ByteVector32): ByteArray {
val key = Sphinx.generateKey("ammag", sharedSecret)
val stream = Sphinx.generateStream(key, PacketLength)
// If we received a packet with an invalid length, we trim and pad to forward a packet with a normal length upstream.
// This is a poor man's attempt at increasing the likelihood of the sender receiving the error.
return Try.Success(packet.take(PacketLength).toByteArray().leftPaddedCopyOf(PacketLength) xor stream)
val stream = Sphinx.generateStream(key, packet.size)
return packet xor stream
}

/**
Expand All @@ -373,23 +361,17 @@ object FailurePacket {
* decrypted, Failure otherwise.
*/
fun decrypt(packet: ByteArray, sharedSecrets: SharedSecrets): Try<DecryptedFailurePacket> {
require(packet.size == PacketLength) { "invalid error packet length ${packet.size}, must be $PacketLength" }

fun loop(packet: ByteArray, secrets: List<Pair<ByteVector32, PublicKey>>): Try<DecryptedFailurePacket> {
return if (secrets.isEmpty()) {
val ex = IllegalArgumentException("couldn't parse error packet=$packet with sharedSecrets=$secrets")
Try.Failure(ex)
} else {
val (secret, pubkey) = secrets.first()
when (val packet1 = tryWrap(packet, secret)) {
is Try.Failure -> Try.Failure(packet1.error)
is Try.Success -> {
val um = Sphinx.generateKey("um", secret)
when (val error = decode(packet1.result, um)) {
is Try.Failure -> loop(packet1.result, secrets.tail())
is Try.Success -> Try.Success(DecryptedFailurePacket(pubkey, error.result))
}
}
val packet1 = wrap(packet, secret)
val um = Sphinx.generateKey("um", secret)
when (val error = decode(packet1, um)) {
is Try.Failure -> loop(packet1, secrets.tail())
is Try.Success -> Try.Success(DecryptedFailurePacket(pubkey, error.result))
}
}
}
Expand Down
Loading

0 comments on commit d5adba0

Please sign in to comment.