Skip to content

Commit

Permalink
improve ETA precision
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Oct 25, 2024
1 parent fff8761 commit 0a8d5d2
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ def process_batch(ex):

return ex

# Compile loop examples before starting time to improve ETA
examples = self.pipeline.make_dataset()

# Loop over data batches with optional progress reporting.
if self.verbosity == "rich":
with rich.progress.Progress(
Expand All @@ -433,7 +436,7 @@ def process_batch(ex):
) as progress:
task = progress.add_task("Predicting...", total=len(data_provider))
last_report = time()
for ex in self.pipeline.make_dataset():
for ex in examples:
ex = process_batch(ex)
progress.update(task, advance=len(ex["frame_ind"]))

Expand All @@ -453,7 +456,7 @@ def process_batch(ex):
last_report = time()
t0_all = time()
t0_batch = time()
for ex in self.pipeline.make_dataset():
for ex in examples:
# Process batch of examples.
ex = process_batch(ex)

Expand Down Expand Up @@ -490,7 +493,7 @@ def process_batch(ex):
# Return results.
yield ex
else:
for ex in self.pipeline.make_dataset():
for ex in examples:
yield process_batch(ex)

def predict(
Expand Down

0 comments on commit 0a8d5d2

Please sign in to comment.