Skip to content

Commit e5429eb

Browse files
Avoid visiting vertices multiple times in Dijkstra's algorithm (#1745)
Fix implementation of Dijsktra to visit each node only once. This is a great speed-up in the worst-case scenario.
1 parent c37eb1a commit e5429eb

File tree

1 file changed

+25
-25
lines changed
  • eclair-core/src/main/scala/fr/acinq/eclair/router

1 file changed

+25
-25
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala

+25-25
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ object Graph {
199199
val bestEdges = mutable.HashMap.newBuilder[PublicKey, GraphEdge](initialCapacity, mutable.HashMap.defaultLoadFactor).result()
200200
// NB: we want the elements with smallest weight first, hence the `reverse`.
201201
val toExplore = mutable.PriorityQueue.empty[WeightedNode](NodeComparator.reverse)
202+
val visitedNodes = mutable.HashSet[PublicKey]()
202203

203204
// initialize the queue and cost array with the initial weight
204205
bestWeights.put(targetNode, initialWeight)
@@ -208,8 +209,9 @@ object Graph {
208209
while (toExplore.nonEmpty && !targetFound) {
209210
// node with the smallest distance from the target
210211
val current = toExplore.dequeue() // O(log(n))
211-
if (current.key != sourceNode) {
212-
val currentWeight = bestWeights(current.key) // NB: there is always an entry for the current in the 'bestWeights' map
212+
targetFound = current.key == sourceNode
213+
if (!targetFound && !visitedNodes.contains(current.key)) {
214+
visitedNodes += current.key
213215
// build the neighbors with optional extra edges
214216
val neighborEdges = {
215217
val extraNeighbors = extraEdges.filter(_.desc.b == current.key)
@@ -220,11 +222,11 @@ object Graph {
220222
val neighbor = edge.desc.a
221223
// NB: this contains the amount (including fees) that will need to be sent to `neighbor`, but the amount that
222224
// will be relayed through that edge is the one in `currentWeight`.
223-
val neighborWeight = addEdgeWeight(sender, edge, currentWeight, currentBlockHeight, wr)
224-
val canRelayAmount = currentWeight.cost <= edge.capacity &&
225-
edge.balance_opt.forall(currentWeight.cost <= _) &&
226-
edge.update.htlcMaximumMsat.forall(currentWeight.cost <= _) &&
227-
currentWeight.cost >= edge.update.htlcMinimumMsat
225+
val neighborWeight = addEdgeWeight(sender, edge, current.weight, currentBlockHeight, wr)
226+
val canRelayAmount = current.weight.cost <= edge.capacity &&
227+
edge.balance_opt.forall(current.weight.cost <= _) &&
228+
edge.update.htlcMaximumMsat.forall(current.weight.cost <= _) &&
229+
current.weight.cost >= edge.update.htlcMinimumMsat
228230
if (canRelayAmount && boundaries(neighborWeight) && !ignoredEdges.contains(edge.desc) && !ignoredVertices.contains(neighbor)) {
229231
val previousNeighborWeight = bestWeights.getOrElse(neighbor, RichWeight(MilliSatoshi(Long.MaxValue), Int.MaxValue, CltvExpiryDelta(Int.MaxValue), Double.MaxValue))
230232
// if this path between neighbor and the target has a shorter distance than previously known, we select it
@@ -238,21 +240,19 @@ object Graph {
238240
}
239241
}
240242
}
241-
} else {
242-
targetFound = true
243243
}
244244
}
245245

246-
targetFound match {
247-
case false => Seq.empty[GraphEdge]
248-
case true =>
249-
val edgePath = new mutable.ArrayBuffer[GraphEdge](RouteCalculation.ROUTE_MAX_LENGTH)
250-
var current = bestEdges.get(sourceNode)
251-
while (current.isDefined) {
252-
edgePath += current.get
253-
current = bestEdges.get(current.get.desc.b)
254-
}
255-
edgePath.toSeq
246+
if (targetFound) {
247+
val edgePath = new mutable.ArrayBuffer[GraphEdge](RouteCalculation.ROUTE_MAX_LENGTH)
248+
var current = bestEdges.get(sourceNode)
249+
while (current.isDefined) {
250+
edgePath += current.get
251+
current = bestEdges.get(current.get.desc.b)
252+
}
253+
edgePath.toSeq
254+
} else {
255+
Seq.empty[GraphEdge]
256256
}
257257
}
258258

@@ -421,9 +421,10 @@ object Graph {
421421
* @return a new graph without this edge
422422
*/
423423
def removeEdge(desc: ChannelDesc): DirectedGraph = {
424-
containsEdge(desc) match {
425-
case true => DirectedGraph(vertices.updated(desc.b, vertices(desc.b).filterNot(_.desc == desc)))
426-
case false => this
424+
if (containsEdge(desc)) {
425+
DirectedGraph(vertices.updated(desc.b, vertices(desc.b).filterNot(_.desc == desc)))
426+
} else {
427+
this
427428
}
428429
}
429430

@@ -555,9 +556,8 @@ object Graph {
555556

556557
def addDescToMap(desc: ChannelDesc, u: ChannelUpdate, capacity: Satoshi, balance_opt: Option[MilliSatoshi]): Unit = {
557558
mutableMap.put(desc.b, GraphEdge(desc, u, getCapacity(capacity, u), balance_opt) +: mutableMap.getOrElse(desc.b, List.empty[GraphEdge]))
558-
mutableMap.get(desc.a) match {
559-
case None => mutableMap += desc.a -> List.empty[GraphEdge]
560-
case _ =>
559+
if (!mutableMap.contains(desc.a)) {
560+
mutableMap += desc.a -> List.empty[GraphEdge]
561561
}
562562
}
563563

0 commit comments

Comments
 (0)