Various minor fixes for typos and types (#105)
* minor fixes for typos and types

* run isort
brentyi authored Jul 8, 2022
1 parent 7428011 commit 14d030f
Showing 6 changed files with 22 additions and 20 deletions.
12 changes: 6 additions & 6 deletions pyrad/engine/trainer.py
@@ -15,6 +15,7 @@
 """
 Code to train model.
 """
+import functools
 import logging
 import os
 from typing import Dict
@@ -82,15 +83,14 @@ def setup(self, test_mode=False):
         self.graph.register_callbacks()
 
     @classmethod
-    def get_aggregated_loss(cls, loss_dict: Dict[str, torch.tensor]):
+    def get_aggregated_loss(cls, loss_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
         """Returns the aggregated losses and the scalar for calling .backwards() on.
         # TODO: move this out to another file/class/etc.
         """
-        loss_sum = 0.0
-        for loss_name in loss_dict.keys():
-            # TODO(ethan): add loss weightings here from a config
-            loss_sum += loss_dict[loss_name]
-        return loss_sum
+        # TODO(ethan): add loss weightings here from a config
+        # e.g. weighted_losses = map(lambda k: some_weight_dict[k] * loss_dict[k], loss_dict.keys())
+        weighted_losses = loss_dict.values()
+        return functools.reduce(torch.add, weighted_losses)
 
     def train(self) -> None:
         """Train the model."""
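For reference, a minimal standalone sketch of what the reduce-based aggregation in get_aggregated_loss computes (the loss names and values below are made up for illustration):

import functools
from typing import Dict

import torch


def aggregate(loss_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    # Sum all loss terms with torch.add so the result stays a Tensor
    # (the removed implementation started the sum from the Python float 0.0).
    return functools.reduce(torch.add, loss_dict.values())


losses = {"rgb_loss": torch.tensor(0.25), "density_loss": torch.tensor(0.05)}
total = aggregate(losses)  # tensor(0.3000), suitable for calling .backward() on

One behavioral difference worth noting: functools.reduce raises a TypeError on an empty loss_dict, whereas the old loop silently returned 0.0.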
2 changes: 1 addition & 1 deletion pyrad/fields/instant_ngp_field.py
@@ -31,7 +31,7 @@
 
 try:
     import tinycudann as tcnn
-except ImportError as e:
+except ImportError:
     # tinycudann module doesn't exist
     pass
 
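The change above only drops the unused exception binding. For context, a common variant of this optional-dependency pattern records availability in a module-level flag; whether pyrad uses such a flag elsewhere is not shown in this diff:

try:
    import tinycudann as tcnn

    TCNN_AVAILABLE = True
except ImportError:
    # tinycudann isn't installed; downstream code can check the flag and fail with a clearer message.
    TCNN_AVAILABLE = False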
2 changes: 1 addition & 1 deletion pyrad/graphs/base.py
@@ -140,7 +140,7 @@ def forward_after_ray_generator(self, ray_bundle: RayBundle, batch: Union[str, D
         """Run forward starting with a ray bundle."""
         intersected_ray_bundle = self.collider(ray_bundle)
 
-        if isinstance(batch, type(None)):
+        if batch is None:
             # during inference, keep all rays
             outputs = self.get_outputs(intersected_ray_bundle)
             return outputs
7 changes: 3 additions & 4 deletions pyrad/utils/config.py
@@ -14,9 +14,9 @@
 
 """Structured config classes"""
 
-from dataclasses import MISSING, dataclass
+from dataclasses import dataclass
 from typing import Any, Dict, Optional
-from omegaconf import DictConfig
+from omegaconf import DictConfig, MISSING
 
 
 @dataclass
@@ -56,8 +56,7 @@ class TrainerConfig:
     steps_per_save: int = MISSING
     steps_per_test: int = MISSING
     max_num_iterations: int = MISSING
-    # additional optional parameters here
-    resume_train: Optional[ResumeTrainConfig] = None
+    resume_train: ResumeTrainConfig = MISSING
 
 
 @dataclass
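A small sketch of how omegaconf's MISSING behaves in a structured config; the class and field below are illustrative, not taken from pyrad:

from dataclasses import dataclass

from omegaconf import MISSING, OmegaConf
from omegaconf.errors import MissingMandatoryValue


@dataclass
class ExampleConfig:
    max_num_iterations: int = MISSING  # mandatory: must be supplied before use


cfg = OmegaConf.structured(ExampleConfig)
try:
    _ = cfg.max_num_iterations  # raises until a value is merged in
except MissingMandatoryValue:
    pass

cfg = OmegaConf.merge(cfg, {"max_num_iterations": 1000})
assert cfg.max_num_iterations == 1000

Unlike dataclasses.MISSING, which simply marks a field as having no default, omegaconf.MISSING keeps the field present in the config and defers the error to the point where the unset value is actually read.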
6 changes: 5 additions & 1 deletion pyrad/utils/profiler.py
@@ -77,7 +77,11 @@ def update_time(self, func_name: str, start_time: float, end_time: float):
     def print_profile(self):
         """helper to print out the profiler stats"""
         logging.info("Printing profiling stats, from longest to shortest duration in seconds")
-        sorted_keys = [k for k, _ in sorted(self.profiler_dict.items(), key=lambda item: item[1]["val"], reverse=True)]
+        sorted_keys = sorted(
+            self.profiler_dict.keys(),
+            key=lambda k: self.profiler_dict[k]["val"],
+            reverse=True,
+        )
         for k in sorted_keys:
             val = f"{self.profiler_dict[k]['val']:0.4f}"
             print(f"{k:<20}: {val:<20}")
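A quick illustration of the sorting logic above, using a toy profiler dict (function names and timings are made up):

profiler_dict = {
    "get_outputs": {"val": 0.1234},
    "render": {"val": 2.5678},
    "load_batch": {"val": 0.9876},
}

# Sort function names from longest to shortest recorded duration.
sorted_keys = sorted(profiler_dict.keys(), key=lambda k: profiler_dict[k]["val"], reverse=True)
assert sorted_keys == ["render", "load_batch", "get_outputs"]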
13 changes: 6 additions & 7 deletions scripts/run_train.py
@@ -2,7 +2,6 @@
 run_train_nerf.py
 """
 
-import datetime
 import logging
 import os
 import random
@@ -56,7 +55,7 @@ def _distributed_worker(
     machine_rank: int,
     dist_url: str,
     config: Config,
-    timeout: datetime = DEFAULT_TIMEOUT,
+    timeout: timedelta = DEFAULT_TIMEOUT,
 ) -> Any:
     """Spawned distributed worker that handles the initialization of process group and handles the
     training process on multiple processes
@@ -70,8 +69,8 @@
         dist_url (str): url to connect to for distributed jobs, including protocol
             e.g. "tcp://127.0.0.1:8686".
             Can be set to "auto" to automatically select a free port on localhost
-        config (Config): config file specifying trainng regimen
-        timeout (datetime, optional): timeout of the distributed workers Defaults to DEFAULT_TIMEOUT.
+        config (Config): config file specifying training regimen
+        timeout (timedelta, optional): timeout of the distributed workers Defaults to DEFAULT_TIMEOUT.
 
     Raises:
         e: Exception in initializing the process group
@@ -137,7 +136,7 @@ def launch(
     machine_rank: int = 0,
     dist_url: str = "auto",
     config: Config = None,
-    timeout: datetime = DEFAULT_TIMEOUT,
+    timeout: timedelta = DEFAULT_TIMEOUT,
 ) -> None:
     """Function that spawns muliple processes to call on main_func
@@ -147,8 +146,8 @@
         num_machines (int, optional): total number of machines
         machine_rank (int, optional): rank of this machine. Defaults to 0.
         dist_url (str, optional): url to connect to for distributed jobs. Defaults to "auto".
-        config (Config, optional): config file specifying trainng regimen Defaults to None.
-        timeout (datetime, optional): timeout of the distributed workers Defaults to DEFAULT_TIMEOUT.
+        config (Config, optional): config file specifying training regimen. Defaults to None.
+        timeout (timedelta, optional): timeout of the distributed workers. Defaults to DEFAULT_TIMEOUT.
     """
     world_size = num_machines * num_gpus_per_machine
     if world_size == 0:
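The timeout parameter is a datetime.timedelta, which matches what torch.distributed expects. A minimal sketch of how such a timeout is typically passed along; DEFAULT_TIMEOUT's actual value is not shown in this diff, so the 30 minutes below is an assumption, and init_group is an illustrative helper rather than pyrad's _distributed_worker:

from datetime import timedelta

import torch.distributed as dist

DEFAULT_TIMEOUT = timedelta(minutes=30)  # assumed value for illustration


def init_group(dist_url: str, world_size: int, rank: int, timeout: timedelta = DEFAULT_TIMEOUT) -> None:
    # torch.distributed takes a timedelta as the timeout for collective operations.
    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        world_size=world_size,
        rank=rank,
        timeout=timeout,
    )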
