diff --git a/tradedb.py b/tradedb.py index fe169e7e..d8dd8338 100644 --- a/tradedb.py +++ b/tradedb.py @@ -17,6 +17,8 @@ import sys from queue import Queue # Because we're British. from collections import namedtuple +import itertools +import math ###################################################################### # Classes @@ -62,13 +64,14 @@ class System(object): # TODO: Build the links from an SQL query, it'll save a lot of # expensive python dictionary lookups. - def __init__(self, ID, system, posX, posY, posZ): - self.ID, self.system, self.posX, self.posY, self.posZ = ID, system, posX, posY, posZ + def __init__(self, ID, name, posX, posY, posZ): + self.ID, self.dbname, self.posX, self.posY, self.posZ = ID, name, posX, posY, posZ self.links = {} self.stations = [] - def addLink(self, dest, dist): - self.links[dest] = dist + @staticmethod + def linkSystems(lhs, rhs, distSq): + lhs.links[rhs] = rhs.links[lhs] = math.sqrt(distSq) def links(self): return list(self.links.keys()) @@ -78,13 +81,13 @@ def addStation(self, station): self.stations.append(station) def name(self): - return self.system.upper() + return self.dbname.upper() def str(self): - return self.system + return self.dbname def __repr__(self): - return "".format(self.ID, self.system, self.posX, self.posY, self.posZ) + return "".format(self.ID, self.dbname, self.posX, self.posY, self.posZ) class Station(object): @@ -93,14 +96,14 @@ class Station(object): opportunities it presents. """ - def __init__(self, ID, system, station, lsFromStar=0.0): - self.ID, self.system, self.station, self.lsFromStar = ID, system, station, lsFromStar + def __init__(self, ID, system, name, lsFromStar=0.0): + self.ID, self.system, self.dbname, self.lsFromStar = ID, system, name, lsFromStar self.trades = {} self.stations = [] system.addStation(self) def name(self): - return self.station + return self.dbname def addTrade(self, dest, item, itemID, costCr, gainCr): """ @@ -115,15 +118,6 @@ def addTrade(self, dest, item, itemID, costCr, gainCr): trade = Trade(item, itemID, costCr, gainCr) self.trades[dstID].append(trade) - def organizeTrades(self): - """ - Process the trades-to-destination lists: sort the list into by-gain order. - """ - # TODO: Read them from the DB in this order. - for tradeList in self.trades.values(): - # sort the list in descending gain order - so the mostprofitable item is listed first. - tradeList.sort(key=lambda trade: trade.gainCr, reverse=True) - def getDestinations(self, maxJumps=None, maxLyPer=None, avoiding=None): """ Gets a list of the Station destinations that can be reached @@ -200,15 +194,15 @@ def getDestinations(self, maxJumps=None, maxLyPer=None, avoiding=None): return destStations def name(self): - return self.station + return self.dbname def str(self): - return '%s %s' % (self.system.name(), self.station) + return '%s %s' % (self.system.name(), self.dbname) def __repr__(self): - return ''.format(self.ID, self.system.name(), self.name()) + return ''.format(self.ID, self.system.name(), self.dbname, self.lsFromStar) -class Ship(namedtuple('Ship', [ 'name', 'capacity', 'maxJump', 'maxJumpFull', 'stations' ])): +class Ship(namedtuple('Ship', [ 'ID', 'name', 'capacity', 'mass', 'driveRating', 'maxLyEmpty', 'maxLyFull', 'maxSpeed', 'boostSpeed', 'stations' ])): pass class TradeDB(object): @@ -258,56 +252,155 @@ def __init__(self, path='.\\TradeDangerous.sq3', debug=0): self.load() + def _load_systems(self): + """ + Initial load the (raw) list of systems. + If you have previously loaded Systems, this will orphan the old System objects. + """ + stmt = """ + SELECT system_id, name, pos_x, pos_y, pos_z + FROM System + """ + self.cur.execute(stmt) + systemByID, systemByName = {}, {} + for (ID, name, posX, posY, posZ) in self.cur: + systemByID[ID] = systemByName[name] = System(ID, name, posX, posY, posZ) + + self.systemByID, self.systemByName = systemByID, systemByName + if self.debug > 1: print("# Loaded %d Systems" % len(systemByID)) + + def _load_stations(self): + """ + Populate the Station list. + Station constructor automatically adds itself to the System object. + If you have previously loaded Stations, this will orphan the old objects. + """ + stmt = """ + SELECT station_id, system_id, name, ls_from_star + FROM Station + """ + self.cur.execute(stmt) + stationByID, stationByName = {}, {} + systemByID = self.systemByID + for (ID, systemID, name, lsFromStar) in self.cur: + stationByID[ID] = stationByName[name] = Station(ID, systemByID[systemID], name, lsFromStar) + + self.stationByID, self.stationByName = stationByID, stationByName + if self.debug > 1: print("# Loaded %d Stations" % len(stationByID)) + + def _load_ships(self): + """ + Populate the Ship list. + If you have previously loaded Ships, this will orphan the old objects. + """ + stmt = """ + SELECT ship_id, name, capacity, mass, drive_rating, max_ly_empty, max_ly_full, max_speed, boost_speed + FROM Ship + """ + self.cur.execute(stmt) + self.shipByID = { row[0]: Ship(*row, stations=[]) for row in self.cur } + + if self.debug > 1: print("# Loaded %d Ships" % len(self.shipByID)) + + def _load_items(self): + """ + Populate the Item list. + If you have previously loaded Items, this will orphan the old objects. + """ + stmt = """ + SELECT item_id, name + FROM Item + """ + self.cur.execute(stmt) + itemByID, itemByName = {}, {} + for (ID, name) in self.cur: + itemByID[ID], itemByName[name] = name, ID + + self.itemByID, self.itemByName = itemByID, itemByName + if self.debug > 1: print("# Loaded %d Items" % len(itemByID)) + + def build_links(self, longestJumpLy): + """ + Populate the list of reachable systems for every star system. + + Not every system can reach every other, and we use the longest jump + that can be made by a ship to limit how many connections we consider + to be "links". + """ + + longestJumpSq = longestJumpLy ** 2 # So we don't have to sqrt every distance + + # Generate a series of symmetric pairs (A->B, A->C, A->D, B->C, B->D, C->D) + # so we only calculate each distance once, and then add a link each way. + # (A->B distance populates A->B and B->A, etc) + numLinks = 0 + for (lhs, rhs) in itertools.combinations(self.systemByID.values(), 2): + dX, dY, dZ = rhs.posX - lhs.posX, rhs.posY - lhs.posY, rhs.posZ - lhs.posZ + distSq = (dX * dX) + (dY * dY) + (dZ * dZ) + if distSq <= longestJumpSq: + System.linkSystems(lhs, rhs, distSq) + numLinks += 1 + + if self.debug > 2: print("# Number of links between systems: %d" % numLinks) + + def load_trades(self): + """ + Load the prices records which indicate that one item sells an item that + another station buys for more, indicating a profitable trade. + Ignore items that have a ui_order of 0 (my way of indicating the item is + either unavailable or black market). + NOTE: Trades MUST be loaded such that they are populated into the + lists in descending order of profit (highest profit first) + """ + stmt = """ + SELECT src.station_id, dst.station_id + , src.item_id + , src.buy_from + , dst.sell_to - src.buy_from AS profit + FROM Price AS src INNER JOIN Price as dst + ON src.item_id = dst.item_id + WHERE src.buy_from > 0 + AND profit > 0 + AND src.ui_order > 0 + AND dst.ui_order > 0 + ORDER BY profit DESC + """ + self.cur.execute(stmt) + stations, items = self.stationByID, self.itemByID + for (srcStnID, dstStnID, itemID, srcCostCr, profitCr) in self.cur: + srcStn, dstStn, item = stations[srcStnID], stations[dstStnID], items[itemID] + srcStn.addTrade(dstStn, item, itemID, srcCostCr, profitCr) + def load(self): - """ Populate/re-populate this instance with data from the TradeDB layer. """ - # Create a cursor. - cur = self.conn.cursor() - - # Fetch a list of systems. - cur.execute('SELECT system FROM Stations GROUP BY system') - systems = self.systems = { row[0]: System(row[0]) for row in cur } - - # Fetch a list of links between systems. - # TODO: Store positions, calculate distances on demand - cur.execute("""SELECT frmSys.system, toSys.system, Links.distLy - FROM Stations AS frmSys, Links, Stations as toSys - WHERE frmSys.ID = Links.from AND toSys.ID = Links.to""") - for (srcSysID, dstSysID, distLy) in cur: - srcSys, dstSys = systems[srcSysID], systems[dstSysID] - srcSys.addLink(dstSys, float(distLy)) - - # Fetch the list of stations - cur.execute('SELECT id, system, station FROM Stations') - # Station lookup by ID - self.stations = { row[0]: Station(row[0], self.systems[row[1]], row[2]) for row in cur } - # StationID lookup by System Name - self.systemIDs = { value.system.str().upper(): key for (key, value) in self.stations.items() } - # StationID lookup by Station Name - self.stationIDs = { value.station.upper(): key for (key, value) in self.stations.items() } - - # Populate 'items' from the database - cur.execute('SELECT id, item FROM Items') - self.items = { row[0]: row[1] for row in cur } - self.itemIDs = { name: itemID for (itemID, name) in self.items.items() } - - stations, items = self.stations, self.items - - # Populate the station list with the profitable trades between stations - # Ignore items that have a ui_order of 0 in the prices table (my way of marking an item as defunct or illegal) - cur.execute('SELECT src.station_id, dst.station_id, src.item_id, src.buy_cr, dst.sell_cr' - ' FROM Prices AS src INNER JOIN Prices AS dst ON src.item_id = dst.item_id' - ' WHERE src.buy_cr > 0 AND dst.sell_cr > src.buy_cr' - ' AND src.ui_order > 0 AND dst.ui_order > 0' - ) - for (srcID, dstID, itemID, srcCostCr, dstValueCr) in cur: - srcStn = stations[srcID] - dstStn = stations[dstID] - item = items[itemID] - srcStn.addTrade(dstStn, item, itemID, srcCostCr, dstValueCr - srcCostCr) - - # Post-process the trades and sort them into whatever order we want them in. - for station in stations.values(): - station.organizeTrades() + """ + Populate/re-populate this instance of TradeDB with data. + WARNING: This will orphan existing records you have + taken references to: + tdb.load() + x = tdb.getStation("Aulin") + tdb.load() # x now points to an orphan Aulin + """ + + self.cur = self.conn.cursor() + + # Load raw tables. Stations will be linked to systems, but nothing else. + # TODO: Make station -> system link a post-load action. + self._load_systems() + self._load_stations() + self._load_ships() + self._load_items() + + systems, stations, ships, items = self.systemByID, self.stationByID, self.shipByID, self.itemByID + + # Calculate the maximum distance anyone can jump so we can constrain + # the maximum "link" between any two stars. + longestJumper = max(ships.values(), key=lambda ship: ship.maxLyEmpty) + self.maxSystemLinkLy = longestJumper.maxLyEmpty + 0.01 + if self.debug > 2: print("# Max ship jump distance: %s @ %f" % (longestJumper.name, self.maxSystemLinkLy)) + + self.build_links(self.maxSystemLinkLy) + + self.load_trades() # In debug mode, check that everything looks sane. if self.debug: @@ -315,17 +408,8 @@ def load(self): def _validate(self): # Check that things correctly reference themselves. - for (stnID, stn) in self.stations.items(): - if self.stations[stn.ID] != stn: - raise ValueError("Station not pointing to self correctly" % stn.station) - for (stnName, stnID) in self.stationIDs.items(): - if self.stations[stnID].station.upper() != stnName: - raise ValueError("Station name not pointing to self correctly" % stnName) - for (itemID, item) in self.items.items(): - if self.itemIDs[item] != itemID: - raise ValueError("Item %s not pointing to itself correctly" % item, itemID, item, self.itemIDs[item]) # Check that system links are bi-directional - for (name, sys) in self.systems.items(): + for (name, sys) in self.systemByName.items(): if not sys.links: raise ValueError("System %s has no links" % name) if sys in sys.links: @@ -336,15 +420,14 @@ def _validate(self): if not sys in link.links: raise ValueError("System %s does not have a reciprocal link in %s's links" % (name, link.str())) - def getSystem(self, name): + def getSystem(self, key): """ Look up a System object by it's name. """ - if isinstance(name, System): + if isinstance(key, System): return name - if isinstance(name, Station): + if isinstance(key, Station): return name.system - system = TradeDB.list_search("System", name, self.systems.keys()) - return self.systems[system] + return TradeDB.list_search("System", name, self.systems.values(), key=lambda system: system.name) def getStation(self, name): """ Look up a Station object by it's name or system. """ @@ -358,8 +441,7 @@ def getStation(self, name): stationID, station, systemID, system = None, None, None, None try: - systemID = TradeDB.list_search("System", name, self.systems.keys()) - system = self.systems[systemID] + system = TradeDB.list_search("System", name, self.systems.values(), key=lambda system: system.name) except LookupError: pass try: