diff --git a/tradedb.py b/tradedb.py index 380a63e4..2e8dddbb 100644 --- a/tradedb.py +++ b/tradedb.py @@ -62,6 +62,19 @@ import sqlite3 import sys +try: + import numpy + import numpy.linalg + haveNumpy = True +except ImportError: + class numpy(object): + array = False + float32 = False + ascontiguousarray = False + class linalg(object): + norm = False + haveNumpy = False + locale.setlocale(locale.LC_ALL, '') @@ -132,7 +145,7 @@ class System(object): __slots__ = ( 'ID', - 'dbname', 'posX', 'posY', 'posZ', 'stations', + 'dbname', 'posX', 'posY', 'posZ', 'pos', 'stations', 'addedID', '_rangeCache' ) @@ -145,10 +158,16 @@ def __init__(self): self.systems = [] self.probedLy = 0. - def __init__(self, ID, dbname, posX, posY, posZ, addedID): + def __init__( + self, ID, dbname, posX, posY, posZ, addedID, + ary=numpy.array, + nptype=numpy.float32, + ): self.ID = ID self.dbname = dbname self.posX, self.posY, self.posZ = posX, posY, posZ + if haveNumpy: + self.pos = ary([posX, posY, posZ], nptype) self.addedID = addedID or 0 self.stations = () self._rangeCache = None @@ -211,6 +230,19 @@ def distanceTo(self, other): (self.posZ - other.posZ) ** 2 ) ** 0.5 # fast sqrt + if haveNumpy: + def all_distances( + self, iterable, + ary=numpy.ascontiguousarray, norm=numpy.linalg.norm, + ): + """ + Takes a list of systems and returns their distances from this system. + """ + return numpy.linalg.norm( + ary([s.pos for s in iterable]) - self.pos, + ord=2, axis=1. + ) + def getStation(self, stationName): """ Quick case-insensitive lookup of a station name within the