Skip to content

Commit

Permalink
Adding LoadingBar.track
Browse files Browse the repository at this point in the history
Fix #964
  • Loading branch information
Yomguithereal committed Sep 6, 2024
1 parent 6989491 commit b4eb647
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions minet/cli/loading_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#
# Various loading bar utilities used by minet CLI.
#
from typing import Optional, Iterable
from typing import Optional, Iterable, TypeVar, Iterator
from minet.types import TypedDict, NotRequired

from contextlib import contextmanager
Expand All @@ -26,6 +26,8 @@
from minet.utils import message_flatmap
from minet.cli.console import console

T = TypeVar("T")


class CautiousBarColumn(BarColumn):
def render(self, task: Task):
Expand Down Expand Up @@ -410,6 +412,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
if not self.known_total:
self.bar_column.style = style
else:
assert self.bar_column is not None
self.bar_column.pulse_style = "success"
self.bar_column.style = "success"

Expand Down Expand Up @@ -443,6 +446,11 @@ def step(self, item=None, count=1, index=None, catch=None, sub_total=None):
if not interrupted:
self.advance(count)

def track(self, iterable: Iterable[T]) -> Iterator[T]:
for item in iterable:
with self.step(item):
yield item

@contextmanager
def nested_step(self, count=1):
assert self.nested
Expand All @@ -461,15 +469,19 @@ def advance(self, count=1):
self.progress.update(self.task_id, advance=count)

def nested_advance(self, count=1):
assert self.sub_progress is not None and self.sub_task_id is not None

self.sub_total_sum += count
self.sub_progress.update(
self.sub_task_id, advance=count, sub_total_sum=self.sub_total_sum
)

def nested_reset(self):
assert self.sub_progress is not None and self.sub_task_id is not None
self.sub_progress.reset(self.sub_task_id)

def __refresh_stats(self):
assert self.stats_progress is not None
self.stats_progress.update(self.stats_task_id, stats=self.stats)

if not self.simple and not self.stats_are_shown:
Expand All @@ -488,14 +500,15 @@ def set_total(self, total: Optional[int] = None):
self.known_total = total is not None

def set_sub_total(self, total: Optional[int] = None):
assert self.sub_progress is not None and self.sub_task_id is not None
self.sub_progress.update(self.sub_task_id, total=total)

def set_label(self, label: str):
assert self.label_progress is not None
assert self.label_progress is not None and self.label_progress_task_id
self.label_progress.update(self.label_progress_task_id, description=label)

# TODO: factorize
def set_stat(self, name: str, count: Optional[int], style: str = None):
def set_stat(self, name: str, count: Optional[int], style: Optional[str] = None):
assert self.stats is not None

if name not in self.stats:
Expand Down Expand Up @@ -534,6 +547,7 @@ def update(self, count=None, sub_title=None, sub_count=None, label=None, **field
self.nested_advance(sub_count)

if sub_title is not None:
assert self.sub_progress is not None and self.sub_task_id is not None
self.sub_progress.update(self.sub_task_id, description=sub_title)

if label is not None:
Expand Down

0 comments on commit b4eb647

Please sign in to comment.