Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: progress bar display #420

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions pymc_extras/inference/pathfinder/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from pytensor.tensor import TensorConstant, TensorVariable
from rich.console import Console, Group
from rich.padding import Padding
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.table import Table
from rich.text import Text

Expand Down Expand Up @@ -1395,7 +1396,7 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:

path_status_message = {
PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter is may be too close to the mean posterior and a poor exploration of the parameter space. Consider increasing jitter if this occurence is high relative to the number of paths.",
PathStatus.INVALID_LOGP: "Invalid logP values occur when a path's logP values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
}

Expand Down Expand Up @@ -1521,12 +1522,20 @@ def multipath_pathfinder(
results = []
compute_start = time.time()
try:
with CustomProgress(
desc = f"Paths Complete: {{path_idx}}/{num_paths}"
progress = CustomProgress(
"[progress.description]{task.description}",
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=default_progress_theme),
disable=not progressbar,
) as progress:
task = progress.add_task("Fitting", total=num_paths)
for result in generator:
)
with progress:
task = progress.add_task(desc.format(path_idx=0), completed=0, total=num_paths)
for path_idx, result in enumerate(generator, start=1):
try:
if isinstance(result, Exception):
raise result
Expand All @@ -1552,7 +1561,14 @@ def multipath_pathfinder(
lbfgs_status=LBFGSStatus.LBFGS_FAILED,
)
)
progress.update(task, advance=1)
finally:
# TODO: display LBFGS and Path Status in real time
progress.update(
task,
description=desc.format(path_idx=path_idx),
completed=path_idx,
refresh=True,
)
except (KeyboardInterrupt, StopIteration) as e:
# if exception is raised here, MultiPathfinderResult will collect all the successful results and report the results. User is free to abort the process earlier and the results will still be collected and return az.InferenceData.
if isinstance(e, StopIteration):
Expand Down
Loading