Skip to content

Commit

Permalink
Experimental numpy usage
Browse files Browse the repository at this point in the history
  • Loading branch information
kfsone committed May 1, 2015
1 parent 5571916 commit 5c99213
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions tradedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@
import sqlite3
import sys

try:
import numpy
import numpy.linalg
haveNumpy = True
except ImportError:
class numpy(object):
array = False
float32 = False
ascontiguousarray = False
class linalg(object):
norm = False
haveNumpy = False

locale.setlocale(locale.LC_ALL, '')


Expand Down Expand Up @@ -132,7 +145,7 @@ class System(object):

__slots__ = (
'ID',
'dbname', 'posX', 'posY', 'posZ', 'stations',
'dbname', 'posX', 'posY', 'posZ', 'pos', 'stations',
'addedID',
'_rangeCache'
)
Expand All @@ -145,10 +158,16 @@ def __init__(self):
self.systems = []
self.probedLy = 0.

def __init__(self, ID, dbname, posX, posY, posZ, addedID):
def __init__(
self, ID, dbname, posX, posY, posZ, addedID,
ary=numpy.array,
nptype=numpy.float32,
):
self.ID = ID
self.dbname = dbname
self.posX, self.posY, self.posZ = posX, posY, posZ
if haveNumpy:
self.pos = ary([posX, posY, posZ], nptype)
self.addedID = addedID or 0
self.stations = ()
self._rangeCache = None
Expand Down Expand Up @@ -211,6 +230,19 @@ def distanceTo(self, other):
(self.posZ - other.posZ) ** 2
) ** 0.5 # fast sqrt

if haveNumpy:
def all_distances(
self, iterable,
ary=numpy.ascontiguousarray, norm=numpy.linalg.norm,
):
"""
Takes a list of systems and returns their distances from this system.
"""
return numpy.linalg.norm(
ary([s.pos for s in iterable]) - self.pos,
ord=2, axis=1.
)

def getStation(self, stationName):
"""
Quick case-insensitive lookup of a station name within the
Expand Down

0 comments on commit 5c99213

Please sign in to comment.