Skip to content

Commit 3fdf9c2

Browse files
committed
Type hints
1 parent 8885fba commit 3fdf9c2

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

Diff for: experiment/experiment_manager.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from boilr.utils import linear_anneal
88
from torch import optim
99
from torch.optim.optimizer import Optimizer
10-
10+
from typing import Optional
1111
from models.lvae import LadderVAE
1212
from .data import DatasetLoader
1313

@@ -319,11 +319,9 @@ def _make_run_description(args: argparse.Namespace) -> str:
319319
s += ',' + args.additional_descr
320320
return s
321321

322-
def forward_pass(self, x, y=None):
323-
"""
324-
Simple single-pass model evaluation. It consists of a forward pass
325-
and computation of all necessary losses and metrics.
326-
"""
322+
def forward_pass(self,
323+
x: torch.Tensor,
324+
y: Optional[torch.Tensor] = None) -> dict:
327325

328326
# Forward pass
329327
x = x.to(self.device, non_blocking=True)
@@ -369,14 +367,20 @@ def forward_pass(self, x, y=None):
369367
return output
370368

371369
@classmethod
372-
def train_log_str(cls, summaries, step, epoch=None):
370+
def train_log_str(cls,
371+
summaries: dict,
372+
step: int,
373+
epoch: Optional[int] = None) -> str:
373374
s = " [step {}] loss: {:.5g} ELBO: {:.5g} recons: {:.3g} KL: {:.3g}"
374375
s = s.format(step, summaries['loss/loss'], summaries['elbo/elbo'],
375376
summaries['elbo/recons'], summaries['elbo/kl'])
376377
return s
377378

378379
@classmethod
379-
def test_log_str(cls, summaries, step, epoch=None):
380+
def test_log_str(cls,
381+
summaries: dict,
382+
step: int,
383+
epoch: Optional[int] = None) -> str:
380384
s = " "
381385
if epoch is not None:
382386
s += "[step {}, epoch {}] ".format(step, epoch)
@@ -396,7 +400,7 @@ def test_log_str(cls, summaries, step, epoch=None):
396400
return s
397401

398402
@classmethod
399-
def get_metrics_dict(cls, results):
403+
def get_metrics_dict(cls, results: dict) -> dict:
400404
metrics_dict = {
401405
'loss/loss': results['loss'].item(),
402406
'elbo/elbo': results['elbo'].item(),

0 commit comments

Comments
 (0)