diff --git a/pyrad/engine/trainer.py b/pyrad/engine/trainer.py
index eca8d307ec..b8fdbe9783 100644
--- a/pyrad/engine/trainer.py
+++ b/pyrad/engine/trainer.py
@@ -15,6 +15,7 @@
 """
 Code to train model.
 """
+import functools
 import logging
 import os
 from typing import Dict
@@ -82,15 +83,14 @@ def setup(self, test_mode=False):
         self.graph.register_callbacks()
 
     @classmethod
-    def get_aggregated_loss(cls, loss_dict: Dict[str, torch.tensor]):
+    def get_aggregated_loss(cls, loss_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
         """Returns the aggregated losses and the scalar for calling .backwards() on.
         # TODO: move this out to another file/class/etc.
         """
-        loss_sum = 0.0
-        for loss_name in loss_dict.keys():
-            # TODO(ethan): add loss weightings here from a config
-            loss_sum += loss_dict[loss_name]
-        return loss_sum
+        # TODO(ethan): add loss weightings here from a config
+        # e.g. weighted_losses = map(lambda k: some_weight_dict[k] * loss_dict[k], loss_dict.keys())
+        weighted_losses = loss_dict.values()
+        return functools.reduce(torch.add, weighted_losses)
 
     def train(self) -> None:
         """Train the model."""
diff --git a/pyrad/fields/instant_ngp_field.py b/pyrad/fields/instant_ngp_field.py
index 9870444e35..12bc0bfc9f 100644
--- a/pyrad/fields/instant_ngp_field.py
+++ b/pyrad/fields/instant_ngp_field.py
@@ -31,7 +31,7 @@
 
 try:
     import tinycudann as tcnn
-except ImportError as e:
+except ImportError:
     # tinycudann module doesn't exist
     pass
 
diff --git a/pyrad/graphs/base.py b/pyrad/graphs/base.py
index d65a977edc..c1c9cf05c4 100644
--- a/pyrad/graphs/base.py
+++ b/pyrad/graphs/base.py
@@ -140,7 +140,7 @@ def forward_after_ray_generator(self, ray_bundle: RayBundle, batch: Union[str, D
         """Run forward starting with a ray bundle."""
         intersected_ray_bundle = self.collider(ray_bundle)
 
-        if isinstance(batch, type(None)):
+        if batch is None:
             # during inference, keep all rays
             outputs = self.get_outputs(intersected_ray_bundle)
             return outputs
diff --git a/pyrad/utils/config.py b/pyrad/utils/config.py
index 8b4fa1afb5..f6ff416768 100644
--- a/pyrad/utils/config.py
+++ b/pyrad/utils/config.py
@@ -14,9 +14,9 @@
 
 """Structured config classes"""
 
-from dataclasses import MISSING, dataclass
+from dataclasses import dataclass
 from typing import Any, Dict, Optional
-from omegaconf import DictConfig
+from omegaconf import DictConfig, MISSING
 
 
 @dataclass
@@ -56,8 +56,7 @@ class TrainerConfig:
     steps_per_save: int = MISSING
     steps_per_test: int = MISSING
     max_num_iterations: int = MISSING
-    # additional optional parameters here
-    resume_train: Optional[ResumeTrainConfig] = None
+    resume_train: ResumeTrainConfig = MISSING
 
 
 @dataclass
diff --git a/pyrad/utils/profiler.py b/pyrad/utils/profiler.py
index fc409a82c7..ee2633a0c8 100644
--- a/pyrad/utils/profiler.py
+++ b/pyrad/utils/profiler.py
@@ -77,7 +77,11 @@ def update_time(self, func_name: str, start_time: float, end_time: float):
     def print_profile(self):
         """helper to print out the profiler stats"""
         logging.info("Printing profiling stats, from longest to shortest duration in seconds")
-        sorted_keys = [k for k, _ in sorted(self.profiler_dict.items(), key=lambda item: item[1]["val"], reverse=True)]
+        sorted_keys = sorted(
+            self.profiler_dict.keys(),
+            key=lambda k: self.profiler_dict[k]["val"],
+            reverse=True,
+        )
         for k in sorted_keys:
             val = f"{self.profiler_dict[k]['val']:0.4f}"
             print(f"{k:<20}: {val:<20}")
diff --git a/scripts/run_train.py b/scripts/run_train.py
index 5d9d3c60c4..b674c2ada6 100644
--- a/scripts/run_train.py
+++ b/scripts/run_train.py
@@ -2,7 +2,6 @@
 """
 run_train_nerf.py
 """
-import datetime
 import logging
 import os
 import random
@@ -56,7 +55,7 @@ def _distributed_worker(
     machine_rank: int,
     dist_url: str,
     config: Config,
-    timeout: datetime = DEFAULT_TIMEOUT,
+    timeout: timedelta = DEFAULT_TIMEOUT,
 ) -> Any:
     """Spawned distributed worker that handles the initialization of process group and handles the
     training process on multiple processes
@@ -70,8 +69,8 @@
         dist_url (str): url to connect to for distributed jobs, including protocol
             e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost
-        config (Config): config file specifying trainng regimen
-        timeout (datetime, optional): timeout of the distributed workers Defaults to DEFAULT_TIMEOUT.
+        config (Config): config file specifying training regimen
+        timeout (timedelta, optional): timeout of the distributed workers Defaults to DEFAULT_TIMEOUT.
 
     Raises:
         e: Exception in initializing the process group
@@ -137,7 +136,7 @@ def launch(
     machine_rank: int = 0,
     dist_url: str = "auto",
     config: Config = None,
-    timeout: datetime = DEFAULT_TIMEOUT,
+    timeout: timedelta = DEFAULT_TIMEOUT,
 ) -> None:
     """Function that spawns muliple processes to call on main_func
 
@@ -147,8 +146,8 @@
         num_machines (int, optional): total number of machines
         machine_rank (int, optional): rank of this machine. Defaults to 0.
         dist_url (str, optional): url to connect to for distributed jobs. Defaults to "auto".
-        config (Config, optional): config file specifying trainng regimen Defaults to None.
-        timeout (datetime, optional): timeout of the distributed workers Defaults to DEFAULT_TIMEOUT.
+        config (Config, optional): config file specifying training regimen. Defaults to None.
+        timeout (timedelta, optional): timeout of the distributed workers. Defaults to DEFAULT_TIMEOUT.
     """
     world_size = num_machines * num_gpus_per_machine
    if world_size == 0:
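Note on the `get_aggregated_loss` change in `pyrad/engine/trainer.py`: a minimal standalone sketch of how folding `loss_dict.values()` with `functools.reduce(torch.add, ...)` behaves. The `aggregate_losses` helper and the loss names below are illustrative, not code from the repo.

```python
import functools
from typing import Dict

import torch


def aggregate_losses(loss_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Fold per-term losses into one scalar tensor for .backward()."""
    # reduce applies torch.add pairwise: add(add(loss_1, loss_2), loss_3), ...
    return functools.reduce(torch.add, loss_dict.values())


losses = {
    "rgb_loss": torch.tensor(0.52, requires_grad=True),
    "density_loss": torch.tensor(0.11, requires_grad=True),
}
total = aggregate_losses(losses)  # tensor(0.6300, grad_fn=<AddBackward0>)
total.backward()                  # gradients flow back to every loss term
```

One behavioral difference from the old `loss_sum = 0.0` accumulator: with an empty `loss_dict`, `functools.reduce` without an initializer raises `TypeError` instead of silently returning the float `0.0`.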
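Note on the `MISSING` swap in `pyrad/utils/config.py`: a minimal sketch of what omegaconf's `MISSING` sentinel does in a structured config. `DemoConfig` is a hypothetical stand-in for `TrainerConfig`, not the real class.

```python
from dataclasses import dataclass

from omegaconf import MISSING, OmegaConf
from omegaconf.errors import MissingMandatoryValue


@dataclass
class DemoConfig:
    # omegaconf treats MISSING ("???") as a mandatory value that has not been set yet
    steps_per_save: int = MISSING


cfg = OmegaConf.structured(DemoConfig)

try:
    _ = cfg.steps_per_save
except MissingMandatoryValue:
    print("steps_per_save is mandatory and has not been provided")

cfg.steps_per_save = 1000  # e.g. supplied later by a YAML file or CLI override
print(cfg.steps_per_save)  # 1000
```

The practical difference from `dataclasses.MISSING` is where the requirement surfaces: with omegaconf's sentinel the config object can be built first and the error appears only when an unset field is actually read, which fits a workflow where values arrive later from YAML or the command line.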
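Note on the `datetime` → `timedelta` annotations in `scripts/run_train.py`: the `timeout` forwarded to `torch.distributed.init_process_group` is a duration, so `datetime.timedelta` is the correct type. A small sketch; `init_worker` and the 30-minute value are illustrative and not taken from the repo (only `torch.distributed.init_process_group` and its `timeout` keyword are existing API).

```python
from datetime import timedelta

import torch.distributed as dist

DEFAULT_TIMEOUT = timedelta(minutes=30)  # illustrative default: a duration, not a point in time


def init_worker(rank: int, world_size: int, dist_url: str, timeout: timedelta = DEFAULT_TIMEOUT) -> None:
    """Join the process group; collectives may wait up to `timeout` before failing."""
    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        rank=rank,
        world_size=world_size,
        timeout=timeout,
    )
```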