Skip to content

Commit 9fe55e1

Browse files
committed
Optimize building of links data by using squared values instead of roots.
This requires the consumer to perform their own math.sqrt() but it reduces the number of them done during loading, etc, and in most cases you can simply compare the square of your desired distance against the stored value. If you are limiting to a max of 5 lightyears, then you are limiting to a max of 5^2 lightyears^2.
1 parent d57d794 commit 9fe55e1

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

trade.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -736,19 +736,20 @@ def localCommand(tdb, args):
736736
args.ship = ship
737737
if args.ly is None: args.ly = (ship.maxLyFull if args.full else ship.maxLyEmpty)
738738
ly = args.ly or tdb.maxSystemLinkLy
739+
lySq = ly ** 2
739740

740741
tdb.buildLinks()
741742

742743
printHeading("Local systems to {} within {} ly.".format(srcSystem.name(), ly))
743744

744745
distances = { }
745746

746-
for (destSys, destDist) in srcSystem.links.items():
747+
for (destSys, destDistSq) in srcSystem.links.items():
747748
if args.debug:
748749
print("Checking {} dist={:5.2f}".format(destSys.str(), destDist))
749-
if destDist > ly:
750+
if destDist > lySq:
750751
continue
751-
distances[destSys] = destDist
752+
distances[destSys] = math.sqrt(destDistSq)
752753

753754
for (system, dist) in sorted(distances.items(), key=lambda x: x[1]):
754755
pillLength = ""
@@ -776,6 +777,7 @@ def navCommand(tdb, args):
776777
args.ship = ship
777778
if args.maxLyPer is None: args.maxLyPer = (ship.maxLyFull if args.full else ship.maxLyEmpty)
778779
maxLyPer = args.maxLyPer or tdb.maxSystemLinkLy
780+
maxLyPerSq = maxLyPer ** 2
779781

780782
if args.debug:
781783
print("# Route from {} to {} with max {} ly per jump.".format(srcSystem.name(), dstSystem.name(), maxLyPer))
@@ -791,10 +793,12 @@ def navCommand(tdb, args):
791793
# nodes that are this many hops out and then clear the list.
792794
openNodes, openList = openList, {}
793795

794-
for (node, startDist) in openNodes.items():
795-
for (destSys, destDist) in node.links.items():
796-
if destDist > maxLyPer:
796+
for (node, startDistSq) in openNodes.items():
797+
startDist = math.sqrt(startDistSq)
798+
for (destSys, destDistSq) in node.links.items():
799+
if destDistSq > maxLyPerSq:
797800
continue
801+
destDist = math.sqrt(destDistSq)
798802
dist = startDist + destDist
799803
# If we already have a shorter path, do nothing
800804
try:

tradedb.py

+14-18
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,6 @@ def __init__(self, ID, dbname, posX, posY, posZ):
6969
self.stations = []
7070

7171

72-
@staticmethod
73-
def linkSystems(lhs, rhs, distSq):
74-
lhs.links[rhs] = rhs.links[lhs] = math.sqrt(distSq)
75-
76-
7772
def addStation(self, station):
7873
if not station in self.stations:
7974
self.stations.append(station)
@@ -116,6 +111,7 @@ def getDestinations(self, maxJumps=None, maxLyPer=None, avoiding=None):
116111
avoiding = avoiding or []
117112
maxJumps = maxJumps or sys.maxsize
118113
maxLyPer = maxLyPer or float("inf")
114+
maxLyPerSq = maxLyPer ** 2
119115

120116
# The open list is the list of nodes we should consider next for
121117
# potential destinations.
@@ -124,13 +120,13 @@ def getDestinations(self, maxJumps=None, maxLyPer=None, avoiding=None):
124120
# The closed list is the list of nodes we've already been to (so
125121
# that we don't create loops A->B->C->A->B->C->...)
126122

127-
Node = namedtuple('Node', [ 'system', 'via', 'distLy' ])
123+
Node = namedtuple('Node', [ 'system', 'via', 'distLySq' ])
128124

129125
openList = [ Node(self.system, [], 0) ]
130-
pathList = { system.ID: Node(system, None, 0.0)
126+
pathList = { system.ID: Node(system, None, -1.0)
131127
# include avoids so we only have
132128
# to consult one place for exclusions
133-
for system in avoiding + [ self ]
129+
for system in avoiding
134130
# the avoid list may contain stations,
135131
# which affects destinations but not vias
136132
if isinstance(system, System) }
@@ -146,18 +142,18 @@ def getDestinations(self, maxJumps=None, maxLyPer=None, avoiding=None):
146142
jumps += 1
147143

148144
for node in ring:
149-
for (destSys, destDist) in node.system.links.items():
150-
if destDist > maxLyPer: continue
151-
dist = node.distLy + destDist
145+
for (destSys, destDistSq) in node.system.links.items():
146+
if destDistSq > maxLyPerSq: continue
147+
distSq = node.distLySq + destDistSq
152148
# If we already have a shorter path, do nothing
153149
try:
154-
if dist >= pathList[destSys.ID].distLy: continue
150+
if distSq >= pathList[destSys.ID].distLySq: continue
155151
except KeyError: pass
156152
# Add to the path list
157-
pathList[destSys.ID] = Node(destSys, node.via, dist)
153+
pathList[destSys.ID] = Node(destSys, node.via, distSq)
158154
# Add to the open list but also include node to the via
159155
# list so that it serves as the via list for all next-hops.
160-
openList += [ Node(destSys, node.via + [destSys], dist) ]
156+
openList += [ Node(destSys, node.via + [destSys], distSq) ]
161157

162158
Destination = namedtuple('Destination', [ 'system', 'station', 'via', 'distLy' ])
163159

@@ -173,10 +169,10 @@ def getDestinations(self, maxJumps=None, maxLyPer=None, avoiding=None):
173169
avoidStations = [ station for station in avoiding if isinstance(station, Station) ]
174170
epsilon = sys.float_info.epsilon
175171
for node in pathList.values():
176-
if node.distLy > epsilon: # Values indistinguishable from zero are avoidances
172+
if node.distLySq >= 0.0: # Values indistinguishable from zero are avoidances
177173
for station in node.system.stations:
178174
if not station in avoidStations:
179-
destStations += [ Destination(node.system, station, [self.system] + node.via + [station.system], node.distLy) ]
175+
destStations += [ Destination(node.system, station, [self.system] + node.via + [station.system], math.sqrt(node.distLySq)) ]
180176

181177
return destStations
182178

@@ -468,8 +464,8 @@ def buildLinks(self):
468464
dX, dY, dZ = rhs.posX - lhs.posX, rhs.posY - lhs.posY, rhs.posZ - lhs.posZ
469465
distSq = (dX * dX) + (dY * dY) + (dZ * dZ)
470466
if distSq <= longestJumpSq:
471-
System.linkSystems(lhs, rhs, distSq)
472-
self.numLinks += 1
467+
lhs.links[rhs] = rhs.links[lhs] = distSq
468+
self.numLinks += 1
473469

474470
if self.debug > 2: print("# Number of links between systems: %d" % self.numLinks)
475471

0 commit comments

Comments
 (0)