7
7
from boilr .utils import linear_anneal
8
8
from torch import optim
9
9
from torch .optim .optimizer import Optimizer
10
-
10
+ from typing import Optional
11
11
from models .lvae import LadderVAE
12
12
from .data import DatasetLoader
13
13
@@ -319,11 +319,9 @@ def _make_run_description(args: argparse.Namespace) -> str:
319
319
s += ',' + args .additional_descr
320
320
return s
321
321
322
- def forward_pass (self , x , y = None ):
323
- """
324
- Simple single-pass model evaluation. It consists of a forward pass
325
- and computation of all necessary losses and metrics.
326
- """
322
+ def forward_pass (self ,
323
+ x : torch .Tensor ,
324
+ y : Optional [torch .Tensor ] = None ) -> dict :
327
325
328
326
# Forward pass
329
327
x = x .to (self .device , non_blocking = True )
@@ -369,14 +367,20 @@ def forward_pass(self, x, y=None):
369
367
return output
370
368
371
369
@classmethod
372
- def train_log_str (cls , summaries , step , epoch = None ):
370
+ def train_log_str (cls ,
371
+ summaries : dict ,
372
+ step : int ,
373
+ epoch : Optional [int ] = None ) -> str :
373
374
s = " [step {}] loss: {:.5g} ELBO: {:.5g} recons: {:.3g} KL: {:.3g}"
374
375
s = s .format (step , summaries ['loss/loss' ], summaries ['elbo/elbo' ],
375
376
summaries ['elbo/recons' ], summaries ['elbo/kl' ])
376
377
return s
377
378
378
379
@classmethod
379
- def test_log_str (cls , summaries , step , epoch = None ):
380
+ def test_log_str (cls ,
381
+ summaries : dict ,
382
+ step : int ,
383
+ epoch : Optional [int ] = None ) -> str :
380
384
s = " "
381
385
if epoch is not None :
382
386
s += "[step {}, epoch {}] " .format (step , epoch )
@@ -396,7 +400,7 @@ def test_log_str(cls, summaries, step, epoch=None):
396
400
return s
397
401
398
402
@classmethod
399
- def get_metrics_dict (cls , results ) :
403
+ def get_metrics_dict (cls , results : dict ) -> dict :
400
404
metrics_dict = {
401
405
'loss/loss' : results ['loss' ].item (),
402
406
'elbo/elbo' : results ['elbo' ].item (),
0 commit comments