From 52bb7f456990a047b0292198d40b3b29ebc0e8d9 Mon Sep 17 00:00:00 2001 From: Oliver 'kfsone' Smith Date: Wed, 1 May 2024 11:35:02 -0700 Subject: [PATCH] feat: Swap out homebrew progress bars for Rich - make bars hideable, - polish, - default the progress bar to visible, require show=False explicitly. - add more bar styles, - use some of the bars, --- tradedangerous/cache.py | 39 ++-- tradedangerous/misc/progress.py | 130 +++++++++---- tradedangerous/tradecalc.py | 315 ++++++++++++++++---------------- tradedangerous/transfers.py | 22 +-- 4 files changed, 279 insertions(+), 227 deletions(-) diff --git a/tradedangerous/cache.py b/tradedangerous/cache.py index afdeca55..751c6ed8 100644 --- a/tradedangerous/cache.py +++ b/tradedangerous/cache.py @@ -31,6 +31,7 @@ import typing from .tradeexcept import TradeException +from tradedangerous.misc.progress import Progress, CountingBar from . import corrections, utils from . import prices @@ -977,25 +978,31 @@ def buildCache(tdb, tdenv): tempDB.executescript(sqlScript) # import standard tables - for (importName, importTable) in tdb.importTables: - try: - processImportFile(tdenv, tempDB, Path(importName), importTable) - except FileNotFoundError: - tdenv.DEBUG0( - "WARNING: processImportFile found no {} file", importName - ) - except StopIteration: - tdenv.NOTE( - "{} exists but is empty. " - "Remove it or add the column definition line.", - importName - ) - - tempDB.commit() + with Progress(max_value=len(tdb.importTables) + 1, width=25, style=CountingBar) as prog: + for (importName, importTable) in tdb.importTables: + with prog.sub_task(description=importName, max_value=None): + prog.increment(value=1, description=importName) + try: + processImportFile(tdenv, tempDB, Path(importName), importTable) + except FileNotFoundError: + tdenv.DEBUG0( + "WARNING: processImportFile found no {} file", importName + ) + except StopIteration: + tdenv.NOTE( + "{} exists but is empty. " + "Remove it or add the column definition line.", + importName + ) + prog.increment(1) + + with prog.sub_task(description="Save DB"): + tempDB.commit() # Parse the prices file if pricesPath.exists(): - processPricesFile(tdenv, tempDB, pricesPath) + with Progress(max_value=None, width=25, prefix="Processing prices file"): + processPricesFile(tdenv, tempDB, pricesPath) else: tdenv.NOTE( "Missing \"{}\" file - no price data.", diff --git a/tradedangerous/misc/progress.py b/tradedangerous/misc/progress.py index 2b5dac8d..558a913b 100644 --- a/tradedangerous/misc/progress.py +++ b/tradedangerous/misc/progress.py @@ -1,52 +1,44 @@ from rich.progress import ( Progress as RichProgress, + TaskID, + ProgressColumn, BarColumn, DownloadColumn, MofNCompleteColumn, SpinnerColumn, TaskProgressColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn, TransferSpeedColumn ) +from contextlib import contextmanager -from typing import Optional +from typing import Iterable, Optional, Type class BarStyle: """ Base class for Progress bar style types. """ - def __init__(self, width: int=10, prefix: Optional[str] = None): - self.columns = [SpinnerColumn()] - if prefix is not None: - self.columns += [TextColumn(prefix)] - self.columns += [BarColumn(bar_width=width)] - self.columns += [TaskProgressColumn()] - self.columns += [TimeElapsedColumn()] + def __init__(self, width: int = 10, prefix: Optional[str] = None, *, add_columns: Optional[Iterable[ProgressColumn]]): + self.columns = [SpinnerColumn(), TextColumn(prefix), BarColumn(bar_width=width)] + if add_columns: + self.columns.extend(add_columns) class CountingBar(BarStyle): """ Creates a progress bar that is counting M/N items to completion. """ - def __init__(self, width: int=10, prefix: Optional[str] = None): - self.columns = [SpinnerColumn()] - if prefix is not None: - self.columns += [TextColumn(prefix)] - self.columns += [BarColumn(bar_width=width)] - self.columns += [MofNCompleteColumn()] - self.columns += [TimeElapsedColumn()] + def __init__(self, width: int = 10, prefix: Optional[str] = None): + my_columns = [MofNCompleteColumn(), TimeElapsedColumn()] + super().__init__(width, prefix, add_columns=my_columns) class DefaultBar(BarStyle): """ Creates a simple default progress bar with a percentage and time elapsed. """ - pass + def __init__(self, width: int = 10, prefix: Optional[str] = None): + my_columns = [TaskProgressColumn(), TimeElapsedColumn()] + super().__init__(width, prefix, add_columns=my_columns) class TransferBar(BarStyle): """ Creates a progress bar representing a data transfer, which shows the amount of data transferred, speed, and estimated time remaining. """ - def __init__(self, width: int=16, prefix: Optional[str] = None): - self.columns = ( - SpinnerColumn(), - TextColumn(prefix), - BarColumn(bar_width=width), - DownloadColumn(), - TransferSpeedColumn(), - TimeRemainingColumn(), - ) + def __init__(self, width: int = 16, prefix: Optional[str] = None): + my_columns = [DownloadColumn(), TransferSpeedColumn(), TimeRemainingColumn()] + super().__init__(width, prefix, add_columns=my_columns) class Progress: @@ -54,23 +46,40 @@ class Progress: Facade around the rich Progress bar system to help transition away from TD's original basic progress bar implementation. """ - def __init__(self, max_value: float, width: int, start: float = 0, prefix: str = "", *, style: BarStyle = DefaultBar) -> None: + def __init__(self, + max_value: Optional[float] = None, + width: Optional[int] = None, + start: float = 0, + prefix: Optional[str] = None, + *, + style: Optional[Type[BarStyle]] = None, + show: bool = True, + ) -> None: """ :param max_value: Last value we can reach (100%). :param width: How wide to make the bar itself. :param start: Override initial value to non-zero. :param prefix: Text to print between the spinner and the bar. + :param style: Bar-style factory to use for styling. + :param show: If False, disables the bar entirely. """ + self.show = bool(show) + if not show: + return + + if style is None: + style = DefaultBar + self.max_value = 0 if max_value is None else max(max_value, start) self.value = start - self.prefix = prefix - self.width = width - + self.prefix = prefix or "" + self.width = width or 25 # The 'Progress' itself is a view for displaying the progress of tasks. So we construct it # and then create a task for our job. + style_instance = style(width=self.width, prefix=self.prefix) self.progress = RichProgress( # What fields to display. - *style(width=self.width, prefix=self.prefix).columns, + *style_instance.columns, # Hide it once it's finished, update it for us, 4x a second transient=True, auto_refresh=True, refresh_per_second=4 ) @@ -84,33 +93,80 @@ def __init__(self, max_value: float, width: int, start: float = 0, prefix: str = self.progress.start() def __enter__(self): + """ Context manager. + + Example use: + + import time + import tradedangerous.progress + + # Progress(max_value=100, width=32, style=progress.CountingBar) + with progress.Progress(100, 32, style=progress.CountingBar) as prog: + for i in range(100): + prog.increment(1) + time.sleep(3) + """ return self def __exit__(self, *args, **kwargs): self.clear() - def increment(self, value: float, description: Optional[str] = None, *, postfix: str = "") -> None: + def increment(self, value: Optional[float] = None, description: Optional[str] = None, *, progress: Optional[float] = None) -> None: """ Increase the progress of the bar by a given amount. :param value: How much to increase the progress by. - :param postfix: [deprecated] text added after the bar :param description: If set, replaces the task description. + :param progress: Instead of increasing by value, set the absolute progress to this. """ + if not self.show: + return if description: - self.progress.update(self.task, description=description) - self.value += value # Update our internal count + self.prefix = description + self.progress.update(self.task, description=description, refresh=True) + + bump = False + if not value and progress is not None and self.value != progress: + self.value = progress + bump = True + elif value: + self.value += value # Update our internal count + bump = True + if self.value >= self.max_value: # Did we go past the end? Increase the end. self.max_value += value * 2 - self.progress.update(self.task, total=self.max_value) - if self.max_value > 0: - self.progress.update(self.task, completed=self.value) + self.progress.update(self.task, description=self.prefix, total=self.max_value) + bump = True + + if bump and self.max_value > 0: + self.progress.update(self.task, description=self.prefix, completed=self.value) def clear(self) -> None: """ Remove the current progress bar, if any. """ + # These two shouldn't happen separately, but incase someone tinkers, test each + # separately and shut them down. + if not self.show: + return + if self.task: self.progress.remove_task(self.task) self.task = None + if self.progress: self.progress.stop() self.progress = None + + @contextmanager + def sub_task(self, description: str, max_value: Optional[int] = None, width: int = 25): + if not self.show: + yield + return + task = self.progress.add_task(description, total=max_value, start=True, width=width) + try: + yield task + finally: + self.progress.remove_task(task) + + def update_task(self, task: TaskID, value: float, description: Optional[str] = None): + if self.show: + self.progress.update(task, value=value, description=description) diff --git a/tradedangerous/tradecalc.py b/tradedangerous/tradecalc.py index ac388b06..42aecebd 100644 --- a/tradedangerous/tradecalc.py +++ b/tradedangerous/tradecalc.py @@ -930,176 +930,173 @@ def station_iterator(srcStation): odyssey = odyssey, ) - prog = pbar.Progress(len(routes), 25) - connections = 0 - getSelling = self.stationsSelling.get - for route in routes: - if tdenv.progress: - prog.increment(1) - tdenv.DEBUG1("Route = {}", route.text(lambda x, y: y)) - - srcStation = route.lastStation - startCr = credits + int(route.gainCr * safetyMargin) - - srcSelling = getSelling(srcStation.ID, None) - srcSelling = tuple( - values for values in srcSelling - if values[1] <= startCr - ) - if not srcSelling: - tdenv.DEBUG1("Nothing sold/affordable - next.") - continue - - if goalSystem: - origSystem = route.firstSystem - srcSystem = srcStation.system - srcDistTo = srcSystem.distanceTo - goalDistTo = goalSystem.distanceTo - origDistTo = origSystem.distanceTo - srcGoalDist = srcDistTo(goalSystem) - srcOrigDist = srcDistTo(origSystem) - origGoalDist = origDistTo(goalSystem) - - if unique: - uniquePath = route.route - elif loopInt: - pos_from_end = 0 - loopInt - uniquePath = route.route[pos_from_end:-1] - - stations = (d for d in station_iterator(srcStation) - if (d.station != srcStation) and - (d.station.blackMarket == 'Y' if reqBlackMarket else True) and - (d.station not in uniquePath if uniquePath else True) and - (d.station in restrictStations if restrictStations else True) and - (d.station.dataAge and d.station.dataAge <= maxAge if maxAge else True) and - (((d.system is not srcSystem) if bool(tdenv.unique) else (d.system is goalSystem or d.distLy < srcGoalDist)) if goalSystem else True) - ) - - if tdenv.debug >= 1: + with pbar.Progress(max_value=len(routes), width=25, show=tdenv.progress) as prog: + connections = 0 + getSelling = self.stationsSelling.get + for route_no, route in enumerate(routes): + prog.increment(progress=route_no) + tdenv.DEBUG1("Route = {}", route.text(lambda x, y: y)) - def annotate(dest): - tdenv.DEBUG1( - "destSys {}, destStn {}, jumps {}, distLy {}", - dest.system.dbname, - dest.station.dbname, - "->".join(jump.text() for jump in dest.via), - dest.distLy - ) - return True - - stations = (d for d in stations if annotate(d)) - - for dest in stations: - dstStation = dest.station + srcStation = route.lastStation + startCr = credits + int(route.gainCr * safetyMargin) - connections += 1 - items = self.getTrades(srcStation, dstStation, srcSelling) - if not items: + srcSelling = getSelling(srcStation.ID, None) + srcSelling = tuple( + values for values in srcSelling + if values[1] <= startCr + ) + if not srcSelling: + tdenv.DEBUG1("Nothing sold/affordable - next.") continue - trade = fitFunction(items, startCr, capacity, maxUnits) - multiplier = 1.0 - # Calculate total K-lightseconds supercruise time. - # This will amortize for the start/end stations - dstSys = dest.system - if goalSystem and dstSys is not goalSystem: - dstGoalDist = goalDistTo(dstSys) - # Biggest reward for shortening distance to goal - score = 5000 * origGoalDist / dstGoalDist - # bias towards bigger reductions - score += 50 * srcGoalDist / dstGoalDist - # discourage moving back towards origin - if dstSys is not origSystem: - score += 10 * (origDistTo(dstSys) - srcOrigDist) - # Gain per unit pays a small part - score += (trade.gainCr / trade.units) / 25 - else: - score = trade.gainCr - if lsPenalty: - # [kfsone] Only want 1dp + if goalSystem: + origSystem = route.firstSystem + srcSystem = srcStation.system + srcDistTo = srcSystem.distanceTo + goalDistTo = goalSystem.distanceTo + origDistTo = origSystem.distanceTo + srcGoalDist = srcDistTo(goalSystem) + srcOrigDist = srcDistTo(origSystem) + origGoalDist = origDistTo(goalSystem) + + if unique: + uniquePath = route.route + elif loopInt: + pos_from_end = 0 - loopInt + uniquePath = route.route[pos_from_end:-1] + + stations = (d for d in station_iterator(srcStation) + if (d.station != srcStation) and + (d.station.blackMarket == 'Y' if reqBlackMarket else True) and + (d.station not in uniquePath if uniquePath else True) and + (d.station in restrictStations if restrictStations else True) and + (d.station.dataAge and d.station.dataAge <= maxAge if maxAge else True) and + (((d.system is not srcSystem) if bool(tdenv.unique) else (d.system is goalSystem or d.distLy < srcGoalDist)) if goalSystem else True) + ) + + if tdenv.debug >= 1: - cruiseKls = int(dstStation.lsFromStar / 100) / 10 - # Produce a curve that favors distances under 1kls - # positively, starts to penalize distances over 1k, - # and after 4kls starts to penalize aggressively - # http://goo.gl/Otj2XP + def annotate(dest): + tdenv.DEBUG1( + "destSys {}, destStn {}, jumps {}, distLy {}", + dest.system.dbname, + dest.station.dbname, + "->".join(jump.text() for jump in dest.via), + dest.distLy + ) + return True - # [eyeonus] As aadler pointed out, this goes into negative - # numbers, which causes problems. - # penalty = ((cruiseKls ** 2) - cruiseKls) / 3 - # penalty *= lsPenalty - # multiplier *= (1 - penalty) + stations = (d for d in stations if annotate(d)) + + for dest in stations: + dstStation = dest.station - # [eyeonus]: - # (Keep in mind all this ignores values of x<0.) - # The sigmoid: (1-(25(x-1))/(1+abs(25(x-1))))/4 - # ranges between 0.5 and 0 with a drop around x=1, - # which makes it great for giving a boost to distances < 1Kls. - # - # The sigmoid: (-1-(50(x-4))/(1+abs(50(x-4))))/4 - # ranges between 0 and -0.5 with a drop around x=4, - # making it great for penalizing distances > 4Kls. - # - # The curve: (-1+1/(x+1)^((x+1)/4))/2 - # ranges between 0 and -0.5 in a smooth arc, - # which will be used for making distances - # closer to 4Kls get a slightly higher penalty - # then distances closer to 1Kls. - # - # Adding the three together creates a doubly-kinked curve - # that ranges from ~0.5 to -1.0, with drops around x=1 and x=4, - # which closely matches ksfone's intention without going into - # negative numbers and causing problems when we add it to - # the multiplier variable. ( 1 + -1 = 0 ) - # - # You can see a graph of the formula here: - # https://goo.gl/sn1PqQ - # NOTE: The black curve is at a penalty of 0%, - # the red curve at a penalty of 100%, with intermediates at - # 25%, 50%, and 75%. - # The other colored lines show the penalty curves individually - # and the teal composite of all three. + connections += 1 + items = self.getTrades(srcStation, dstStation, srcSelling) + if not items: + continue + trade = fitFunction(items, startCr, capacity, maxUnits) - def sigmoid(x): - return x / (1 + abs(x)) + multiplier = 1.0 + # Calculate total K-lightseconds supercruise time. + # This will amortize for the start/end stations + dstSys = dest.system + if goalSystem and dstSys is not goalSystem: + dstGoalDist = goalDistTo(dstSys) + # Biggest reward for shortening distance to goal + score = 5000 * origGoalDist / dstGoalDist + # bias towards bigger reductions + score += 50 * srcGoalDist / dstGoalDist + # discourage moving back towards origin + if dstSys is not origSystem: + score += 10 * (origDistTo(dstSys) - srcOrigDist) + # Gain per unit pays a small part + score += (trade.gainCr / trade.units) / 25 + else: + score = trade.gainCr + if lsPenalty: + # [kfsone] Only want 1dp + + cruiseKls = int(dstStation.lsFromStar / 100) / 10 + # Produce a curve that favors distances under 1kls + # positively, starts to penalize distances over 1k, + # and after 4kls starts to penalize aggressively + # http://goo.gl/Otj2XP + + # [eyeonus] As aadler pointed out, this goes into negative + # numbers, which causes problems. + # penalty = ((cruiseKls ** 2) - cruiseKls) / 3 + # penalty *= lsPenalty + # multiplier *= (1 - penalty) + + # [eyeonus]: + # (Keep in mind all this ignores values of x<0.) + # The sigmoid: (1-(25(x-1))/(1+abs(25(x-1))))/4 + # ranges between 0.5 and 0 with a drop around x=1, + # which makes it great for giving a boost to distances < 1Kls. + # + # The sigmoid: (-1-(50(x-4))/(1+abs(50(x-4))))/4 + # ranges between 0 and -0.5 with a drop around x=4, + # making it great for penalizing distances > 4Kls. + # + # The curve: (-1+1/(x+1)^((x+1)/4))/2 + # ranges between 0 and -0.5 in a smooth arc, + # which will be used for making distances + # closer to 4Kls get a slightly higher penalty + # then distances closer to 1Kls. + # + # Adding the three together creates a doubly-kinked curve + # that ranges from ~0.5 to -1.0, with drops around x=1 and x=4, + # which closely matches ksfone's intention without going into + # negative numbers and causing problems when we add it to + # the multiplier variable. ( 1 + -1 = 0 ) + # + # You can see a graph of the formula here: + # https://goo.gl/sn1PqQ + # NOTE: The black curve is at a penalty of 0%, + # the red curve at a penalty of 100%, with intermediates at + # 25%, 50%, and 75%. + # The other colored lines show the penalty curves individually + # and the teal composite of all three. + + def sigmoid(x): + return x / (1 + abs(x)) + + boost = (1 - sigmoid(25 * (cruiseKls - 1))) / 4 + drop = (-1 - sigmoid(50 * (cruiseKls - 4))) / 4 + try: + penalty = (-1 + 1 / (cruiseKls + 1) ** ((cruiseKls + 1) / 4)) / 2 + except OverflowError: + penalty = -0.5 + + multiplier += (penalty + boost + drop) * lsPenalty - boost = (1 - sigmoid(25 * (cruiseKls - 1))) / 4 - drop = (-1 - sigmoid(50 * (cruiseKls - 4))) / 4 - try: - penalty = (-1 + 1 / (cruiseKls + 1) ** ((cruiseKls + 1) / 4)) / 2 - except OverflowError: - penalty = -0.5 + score *= multiplier - multiplier += (penalty + boost + drop) * lsPenalty - - score *= multiplier - - dstID = dstStation.ID - try: - # See if there is already a candidate for this destination - btd = bestToDest[dstID] - except KeyError: - # No existing candidate, we win by default - pass - else: - bestRoute = btd[1] - bestScore = btd[5] - # Check if it is a better option than we just produced - bestTradeScore = bestRoute.score + bestScore - newTradeScore = route.score + score - if bestTradeScore > newTradeScore: - continue - if bestTradeScore == newTradeScore: - bestLy = btd[4] - if bestLy <= dest.distLy: + dstID = dstStation.ID + try: + # See if there is already a candidate for this destination + btd = bestToDest[dstID] + except KeyError: + # No existing candidate, we win by default + pass + else: + bestRoute = btd[1] + bestScore = btd[5] + # Check if it is a better option than we just produced + bestTradeScore = bestRoute.score + bestScore + newTradeScore = route.score + score + if bestTradeScore > newTradeScore: continue - - bestToDest[dstID] = ( - dstStation, route, trade, dest.via, dest.distLy, score - ) - - prog.clear() - + if bestTradeScore == newTradeScore: + bestLy = btd[4] + if bestLy <= dest.distLy: + continue + + bestToDest[dstID] = ( + dstStation, route, trade, dest.via, dest.distLy, score + ) + if connections == 0: raise NoHopsError( "No destinations could be reached within the constraints." diff --git a/tradedangerous/transfers.py b/tradedangerous/transfers.py index c61cccc6..f343bb4a 100644 --- a/tradedangerous/transfers.py +++ b/tradedangerous/transfers.py @@ -17,7 +17,7 @@ if typing.TYPE_CHECKING: import os # for PathLike from .tradeenv import TradeEnv - from typing import Optional, Union + from typing import Callable, Optional, Union ###################################################################### @@ -51,7 +51,7 @@ def download( localFile: os.PathLike, headers: Optional[dict] = None, backup: bool = False, - shebang: Optional[str] = None, + shebang: Optional[Callable] = None, chunkSize: int = 4096, timeout: int = 90, *, @@ -106,19 +106,14 @@ def download( tdenv.NOTE("Downloading {} {}ed data", transfer, encoding) tdenv.DEBUG0(str(req.headers).replace("{", "{{").replace("}", "}}")) - # Figure out how much data we have - if not tdenv.quiet: - filename = get_filename_from_url(url) - progBar = pbar.Progress(length, 20, prefix=filename, style=pbar.TransferBar) - else: - progBar = None - actPath = Path(localFile) fs.ensurefolder(tdenv.tmpDir) tmpPath = Path(tdenv.tmpDir, "{}.dl".format(actPath.name)) fetched = 0 - with tmpPath.open("wb") as fh: + started = time.time() + filename = get_filename_from_url(url) + with pbar.Progress(max_value=length, width=25, prefix=filename, style=pbar.CountingBar, show=not tdenv.quiet) as prog, tmpPath.open("wb") as fh: for data in req.iter_content(chunk_size=chunkSize): fh.write(data) fetched += len(data) @@ -127,13 +122,10 @@ def download( tdenv.DEBUG0("Checking shebang of {}", bangLine) shebang(bangLine) shebang = None - if progBar: - progBar.increment(len(data)) + if prog: + prog.increment(len(data)) tdenv.DEBUG0("End of data") - if progBar: - progBar.clear() - if not tdenv.quiet: elapsed = (time.time() - started) or 1 tdenv.NOTE(