From 2779995072156b5a80eaf572f65a1cc582cf82f5 Mon Sep 17 00:00:00 2001
From: Brent Yi
Date: Thu, 7 Jul 2022 17:48:03 -0700
Subject: [PATCH] minor fixes for typos and types

---
 pyrad/engine/trainer.py           | 12 ++++++------
 pyrad/fields/instant_ngp_field.py |  2 +-
 pyrad/graphs/base.py              |  2 +-
 pyrad/utils/config.py             |  7 +++----
 pyrad/utils/profiler.py           |  6 +++++-
 scripts/run_train.py              | 13 ++++++-------
 6 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/pyrad/engine/trainer.py b/pyrad/engine/trainer.py
index eca8d307ec5..61d496bcb89 100644
--- a/pyrad/engine/trainer.py
+++ b/pyrad/engine/trainer.py
@@ -23,6 +23,7 @@
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torchtyping import TensorType
+import functools
 
 from pyrad.cameras.rays import RayBundle
 from pyrad.data.dataloader import EvalDataloader, setup_dataset_eval, setup_dataset_train
@@ -82,15 +83,14 @@ def setup(self, test_mode=False):
         self.graph.register_callbacks()
 
     @classmethod
-    def get_aggregated_loss(cls, loss_dict: Dict[str, torch.tensor]):
+    def get_aggregated_loss(cls, loss_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
         """Returns the aggregated losses and the scalar for calling .backwards() on.
         # TODO: move this out to another file/class/etc.
         """
-        loss_sum = 0.0
-        for loss_name in loss_dict.keys():
-            # TODO(ethan): add loss weightings here from a config
-            loss_sum += loss_dict[loss_name]
-        return loss_sum
+        # TODO(ethan): add loss weightings here from a config
+        # e.g. weighted_losses = map(lambda k: some_weight_dict[k] * loss_dict[k], loss_dict.keys())
+        weighted_losses = loss_dict.values()
+        return functools.reduce(torch.add, weighted_losses)
 
     def train(self) -> None:
         """Train the model."""
diff --git a/pyrad/fields/instant_ngp_field.py b/pyrad/fields/instant_ngp_field.py
index 9870444e354..12bc0bfc9fe 100644
--- a/pyrad/fields/instant_ngp_field.py
+++ b/pyrad/fields/instant_ngp_field.py
@@ -31,7 +31,7 @@
 
 try:
     import tinycudann as tcnn
-except ImportError as e:
+except ImportError:
     # tinycudann module doesn't exist
     pass
 
diff --git a/pyrad/graphs/base.py b/pyrad/graphs/base.py
index 8c4d1b9f460..9230631a0dd 100644
--- a/pyrad/graphs/base.py
+++ b/pyrad/graphs/base.py
@@ -142,7 +142,7 @@ def forward_after_ray_generator(self, ray_bundle: RayBundle, batch: Union[str, D
         """Run forward starting with a ray bundle."""
         intersected_ray_bundle = self.collider(ray_bundle)
 
-        if isinstance(batch, type(None)):
+        if batch is None:
             # during inference, keep all rays
             outputs = self.get_outputs(intersected_ray_bundle)
             return outputs
diff --git a/pyrad/utils/config.py b/pyrad/utils/config.py
index 8b4fa1afb55..f6ff4167686 100644
--- a/pyrad/utils/config.py
+++ b/pyrad/utils/config.py
@@ -14,9 +14,9 @@
 
 """Structured config classes"""
 
-from dataclasses import MISSING, dataclass
+from dataclasses import dataclass
 from typing import Any, Dict, Optional
-from omegaconf import DictConfig
+from omegaconf import DictConfig, MISSING
 
 
 @dataclass
@@ -56,8 +56,7 @@ class TrainerConfig:
     steps_per_save: int = MISSING
     steps_per_test: int = MISSING
     max_num_iterations: int = MISSING
-    # additional optional parameters here
-    resume_train: Optional[ResumeTrainConfig] = None
+    resume_train: ResumeTrainConfig = MISSING
 
 
 @dataclass
diff --git a/pyrad/utils/profiler.py b/pyrad/utils/profiler.py
index fc409a82c7d..ee2633a0c89 100644
--- a/pyrad/utils/profiler.py
+++ b/pyrad/utils/profiler.py
@@ -77,7 +77,11 @@ def update_time(self, func_name: str, start_time: float, end_time: float):
     def print_profile(self):
"""helper to print out the profiler stats""" logging.info("Printing profiling stats, from longest to shortest duration in seconds") - sorted_keys = [k for k, _ in sorted(self.profiler_dict.items(), key=lambda item: item[1]["val"], reverse=True)] + sorted_keys = sorted( + self.profiler_dict.keys(), + key=lambda k: self.profiler_dict[k]["val"], + reverse=True, + ) for k in sorted_keys: val = f"{self.profiler_dict[k]['val']:0.4f}" print(f"{k:<20}: {val:<20}") diff --git a/scripts/run_train.py b/scripts/run_train.py index 5d9d3c60c48..bc007ad5d17 100644 --- a/scripts/run_train.py +++ b/scripts/run_train.py @@ -2,7 +2,6 @@ run_train_nerf.py """ -import datetime import logging import os import random @@ -56,7 +55,7 @@ def _distributed_worker( machine_rank: int, dist_url: str, config: Config, - timeout: datetime = DEFAULT_TIMEOUT, + timeout: timedelta = DEFAULT_TIMEOUT, ) -> Any: """Spawned distributed worker that handles the initialization of process group and handles the training process on multiple processes @@ -70,8 +69,8 @@ def _distributed_worker( dist_url (str): url to connect to for distributed jobs, including protocol e.g. "tcp://127.0.0.1:8686". Can be set to "auto" to automatically select a free port on localhost - config (Config): config file specifying trainng regimen - timeout (datetime, optional): timeout of the distributed workers Defaults to DEFAULT_TIMEOUT. + config (Config): config file specifying training regimen + timeout (timedelta, optional): timeout of the distributed workers Defaults to DEFAULT_TIMEOUT. Raises: e: Exception in initializing the process group @@ -137,7 +136,7 @@ def launch( machine_rank: int = 0, dist_url: str = "auto", config: Config = None, - timeout: datetime = DEFAULT_TIMEOUT, + timeout: timedelta = DEFAULT_TIMEOUT, ) -> None: """Function that spawns muliple processes to call on main_func @@ -147,8 +146,8 @@ def launch( num_machines (int, optional): total number of machines machine_rank (int, optional): rank of this machine. Defaults to 0. dist_url (str, optional): url to connect to for distributed jobs. Defaults to "auto". - config (Config, optional): config file specifying trainng regimen Defaults to None. - timeout (datetime, optional): timeout of the distributed workers Defaults to DEFAULT_TIMEOUT. + config (Config, optional): config file specifying training regimen. Defaults to None. + timeout (timedelta, optional): timeout of the distributed workers Defaults to DEFAULT_TIMEOUT. """ world_size = num_machines * num_gpus_per_machine if world_size == 0: