Skip to content

Commit

Permalink
Sqrt is not the enemy, math.sqrt() is the enemy.
Browse files Browse the repository at this point in the history
I realized that it's the function call part of math.sqrt not the actual root that makes it so expensive. Timing suggests that calling "math.sqrt" costs ~300ns on a 1st gen i7 vs 29ns to do "** 0.5", so I introduced a "Station.distanceTo" call and eliminated all calls to math.sqrt in favor of ** 0.5.
  • Loading branch information
kfsone committed Feb 6, 2015
1 parent d756b9e commit b81e6b0
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 68 deletions.
5 changes: 5 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
TradeDangerous, Copyright (C) Oliver "kfsone" Smith, July 2014
==============================================================================

[wip] Feb 05 2015
. (kfsone) Added "distanceTo" function to "System" which returns the
distance in ly to a second system,
. (kfsone) Minor performance gain in distance calculations,

v6.8.5 Feb 04 2015
. (kfsone) Added "trade" command to list station-to-station trades,
. (kfsone) "station" command now lists 5 sell and 5 buy items,
Expand Down
28 changes: 7 additions & 21 deletions commands/local_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,36 +57,22 @@ def run(results, cmdenv, tdb):
distances = { srcSystem: 0.0 }

# Calculate the bounding dimensions
srcX, srcY, srcZ = srcSystem.posX, srcSystem.posY, srcSystem.posZ
lySq = ly * ly

for destSys in tdb.systems():
distSq = (
(destSys.posX - srcX) ** 2 +
(destSys.posY - srcY) ** 2 +
(destSys.posZ - srcZ) ** 2
)
if distSq <= lySq and destSys is not srcSystem:
distances[destSys] = math.sqrt(distSq)
for destSys, dist in tdb.genSystemsInRange(srcSystem, ly):
distances[destSys] = dist

showStations = cmdenv.detail
if showStations:
stationIDs = ",".join([
",".join(str(stn.ID) for stn in sys.stations)
for sys in distances.keys()
if sys.stations
])
stmt = """
SELECT si.station_id,
JULIANDAY('NOW') - JULIANDAY(MIN(si.modified))
FROM StationItem AS si
WHERE si.station_id IN ({})
GROUP BY 1
""".format(stationIDs)
"""
cmdenv.DEBUG0("Fetching ages: {}", stmt)
ages = {}
for ID, age in tdb.query(stmt):
ages[ID] = age
ages = {
ID: age
for ID, age in tdb.query(stmt)
}

padSize = cmdenv.padSize
for (system, dist) in sorted(distances.items(), key=lambda x: x[1]):
Expand Down
4 changes: 2 additions & 2 deletions commands/nav_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ def run(results, cmdenv, tdb):
ages[ID] = age

for (jumpSys, dist) in route:
jumpLy = math.sqrt(lastSys.distToSq(jumpSys))
jumpLy = lastSys.distanceTo(jumpSys)
totalLy += jumpLy
if cmdenv.detail:
dirLy = math.sqrt(jumpSys.distToSq(dstSystem))
dirLy = jumpSys.distanceTo(dstSystem)
row = ResultRow(
action='Via',
system=jumpSys,
Expand Down
5 changes: 3 additions & 2 deletions commands/olddata_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def run(results, cmdenv, tdb):
row.ls = "{:n}".format(ls)
else:
row.ls = "?"
row.dist2 = dist2
row.dist = dist2 ** 0.5
results.rows.append(row)

return results
Expand All @@ -164,7 +164,8 @@ def render(results, cmdenv, tdb):

if cmdenv.nearSystem:
rowFmt.addColumn('DistLy', '>', 6, '.2f',
key=lambda row: math.sqrt(row.dist2))
key=lambda row: row.dist
)

rowFmt.append(
ColumnFormat("Age/days", '>', '8', '.2f',
Expand Down
19 changes: 7 additions & 12 deletions commands/rares_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,33 +104,28 @@ def run(results, cmdenv, tdb):
start = cmdenv.nearSystem
# Hoist the padSize parameter for convenience
padSize = cmdenv.padSize
# How far we're want to cast our net.
maxLy = float(cmdenv.maxLyPer or 0.)

# Start to build up the results data.
results.summary = ResultRow()
results.summary.near = start
results.summary.ly = cmdenv.maxLyPer

# The last step in calculating the distance between two
# points is to perform a square root. However, we can avoid
# the cost of doing this by squaring the distance we need
# to check and only 'rooting values that are <= to it.
maxLySq = cmdenv.maxLyPer ** 2
results.summary.ly = maxLy

# Look through the rares list.
for rare in tdb.rareItemByID.values():
if padSize: # do we care about pad size?
if not rare.station.checkPadSize(padSize):
continue
# Find the un-sqrt'd distance to the system.
distSq = start.distToSq(rare.station.system)
if maxLySq > 0: # do we have a limit on distance?
if distSq > maxLySq:
continue
dist = start.distanceTo(rare.station.system)
if maxLy > 0. and dist > maxLy:
continue

# Create a row for this item
row = ResultRow()
row.rare = rare
row.dist = math.sqrt(distSq)
row.dist = dist
results.rows.append(row)

# Was anything matched?
Expand Down
2 changes: 1 addition & 1 deletion edscupdate.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def get_cmdr(tdb):


def dist(x, y, z):
return math.sqrt((ox-x)**2 + (oy-y)**2 + (oz-z)**2)
return ((ox-x)**2 + (oy-y)**2 + (oz-z)**2) ** 0.5


def ischange(sysinfo):
Expand Down
2 changes: 1 addition & 1 deletion submit-distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def check_system(tdb, tdbSys, name):
return

print("KNOWN SYSTEM: {:.2f} ly".format(
math.sqrt(tdbSys.distToSq(system))
tdbSys.distanceTo(system)
))


Expand Down
83 changes: 54 additions & 29 deletions tradedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,27 +153,18 @@ def distToSq(self, other):
"""
Returns the square of the distance between two systems.
Optimization Note:
This function returns the SQUARE of the distance.
For any given pair of numbers (n, m), if n > m then n^2 > m^2
and if n < m then n^2 < m^2 and if n == m n^2 == m^2.
The final step in a distance calculation is a sqrt() function,
which is expensive.
So when you only need distances for comparative purposes, such
as comparing a set of points against a given distance, it is
much more efficient to square the comparitor and test it
against the un-rooted distances.
It is slightly cheaper to calculate the square of the
distance between two points, so when you are primarily
doing distance checks you can use this less expensive
distance query and only perform a sqrt (** 0.5) on the
distances that fall within your constraint.
Args:
other:
The other System to measure the distance between.
Returns:
Distance in light years (squared).
Distance in light years to the power of 2 (i.e. squared).
Example:
# Calculate which of [systems] is within 12 ly
Expand All @@ -183,19 +174,39 @@ def distToSq(self, other):
for sys in systems:
if sys.distToSq(target) <= maxLySq:
inRange.append(sys)
"""

dx2 = (self.posX - other.posX) ** 2
dy2 = (self.posY - other.posY) ** 2
dz2 = (self.posZ - other.posZ) ** 2

return (dx2 + dy2 + dz2)

def distanceTo(self, other):
"""
Returns the distance (in ly) between two systems.
NOTE: If you are primarily testing/comparing
distances, consider using "distToSq" for the test.
# Print the distance between two systems
print("{} -> {}: {}ly".format(
lhs.name(), rhs.name(),
math.sqrt(lhs.distToSq(rhs)),
Returns:
Distance in light years.
Example:
print("{} -> {}: {} ly".format(
lhs.name(), rhs.name(),
lhs.distanceTo(rhs),
))
"""

dx2 = (self.posX - other.posX) ** 2
dy2 = (self.posY - other.posY) ** 2
dz2 = (self.posZ - other.posZ) ** 2

return (dx2 + dy2 + dz2)
distSq = (dx2 + dy2 + dz2)

return distSq ** 0.5


def name(self):
return self.dbname
Expand Down Expand Up @@ -450,8 +461,11 @@ class TradeDB(object):
List of the .csv files
Static methods:
calculateDitance2(lx, ly, lz, rx, ry, rz)
Returns the square of the distance between two points.
calculateDistance2(lx, ly, lz, rx, ry, rz)
Returns the square of the distance in ly between two points.
calculateDistance(lx, ly, lz, rx, ry, rz)
Returns the distance in ly between two points.
listSearch(...)
Performs partial and ambiguity matching of a word from a list
Expand Down Expand Up @@ -540,12 +554,24 @@ def __init__(
@staticmethod
def calculateDistance2(lx, ly, lz, rx, ry, rz):
"""
Returns the square of the distance between two points
Returns the distance in ly between two points.
"""
dX = (lx - rx)
dY = (ly - ry)
dZ = (lz - rz)
distSq = (dX ** 2) + (dY ** 2) + (dZ ** 2)
return distSq

@staticmethod
def calculateDistance(lx, ly, lz, rx, ry, rz):
"""
Returns the distance in ly between two points.
"""
dX = (lx - rx)
dY = (ly - ry)
dZ = (lz - rz)
return (dX ** 2) + (dY ** 2) + (dZ ** 2)
distSq = (dX ** 2) + (dY ** 2) + (dZ ** 2)
return distSq ** 0.5

############################################################
# Access to the underlying database.
Expand Down Expand Up @@ -756,7 +782,7 @@ def genSystemsInRange(self, system, ly, includeSelf=False):
candidate:
System that was found,
distLy:
The distance in lightyears betwen system and candidate.
The distance in lightyears between system and candidate.
"""

if isinstance(system, Station):
Expand All @@ -778,7 +804,7 @@ def genSystemsInRange(self, system, ly, includeSelf=False):
if cand is not system:
cachedSystems.append((
cand,
math.sqrt(distSq)
distSq ** 0.5,
))

cachedSystems.sort(key=lambda ent: ent[1])
Expand Down Expand Up @@ -882,10 +908,9 @@ def getRoute(self, origin, dest, maxJumpLy, avoiding=[]):
except KeyError:
pass
distances[nSys] = (curSys, newDist)
weight = math.sqrt(curSys.distToSq(nSys))
weight = curSys.distanceTo(nSys)
nID = nSys.ID
# + 1 adds a penalty per jump
heapq.heappush(openSet, (newDist + weight + 1, newDist, nID))
heapq.heappush(openSet, (newDist + weight, newDist, nID))
if nID == destID:
distTo = newDist

Expand Down

0 comments on commit b81e6b0

Please sign in to comment.