
Commit acc5462

Author: Noé Pion
feat: improve launch device
1 parent cef9589 commit acc5462

File tree: 3 files changed, +27 -2 lines changed


nerfstudio/configs/base_config.py (+2, -2)

@@ -68,8 +68,8 @@ class MachineConfig(PrintableConfig):
     """current machine's rank (for DDP)"""
     dist_url: str = "auto"
     """distributed connection point (for DDP)"""
-    device_type: Literal["cpu", "cuda", "mps"] = "cuda"
-    """device type to use for training"""
+    device_type: Literal["cpu", "cuda", "mps"] | None = None
+    """device type to use for training. If none set, script will do its best to set the value."""


 @dataclass
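
For readers skimming the diff: the field change above is equivalent to making the annotation Optional, so a loaded config can now carry no device preference at all and defer the choice to the train script. A minimal standalone sketch of just this field (a simplified, hypothetical stand-in, not the full MachineConfig from nerfstudio):

from dataclasses import dataclass
from typing import Literal, Optional


@dataclass
class MachineConfigSketch:  # hypothetical stand-in for the real MachineConfig
    device_type: Optional[Literal["cpu", "cuda", "mps"]] = None
    """Device type to use for training. If None, the train script resolves it at launch."""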

nerfstudio/scripts/train.py (+7)

@@ -62,6 +62,7 @@
 from nerfstudio.configs.method_configs import AnnotatedBaseConfigUnion
 from nerfstudio.engine.trainer import TrainerConfig
 from nerfstudio.utils import comms, profiler
+from nerfstudio.utils.best_device import get_best_device
 from nerfstudio.utils.rich_utils import CONSOLE

 DEFAULT_TIMEOUT = timedelta(minutes=30)

@@ -238,6 +239,12 @@ def main(config: TrainerConfig) -> None:
         CONSOLE.log(f"Loading pre-set config from: {config.load_config}")
         config = yaml.load(config.load_config.read_text(), Loader=yaml.Loader)

+    if not hasattr(config.machine, "device_type") or config.machine.device_type is None:
+        device_type, reason = get_best_device()
+        config.machine.device_type = device_type
+        CONSOLE.log(f"No device specified: {reason}")
+        CONSOLE.log("You can force a certain device type with --machine.device_type [cuda|mps|cpu]")
+
     config.set_timestamp()

     # print and save config
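
To see what the new block in main() does in isolation, here is a hedged sketch: the SimpleNamespace object is a simplified stand-in for config.machine (the real object is a MachineConfig dataclass), and print stands in for CONSOLE.log.

from types import SimpleNamespace

from nerfstudio.utils.best_device import get_best_device

machine = SimpleNamespace(device_type=None)  # simplified stand-in for config.machine

if not hasattr(machine, "device_type") or machine.device_type is None:
    device_type, reason = get_best_device()
    machine.device_type = device_type
    print(f"No device specified: {reason}")
    print("You can force a certain device type with --machine.device_type [cuda|mps|cpu]")

print(machine.device_type)  # "cuda", "mps", or "cpu" depending on the host

The hasattr guard presumably keeps configs serialized by older nerfstudio versions loadable, since yaml.load restores a config object that may predate the device_type field.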

nerfstudio/utils/best_device.py (+18, new file)

@@ -0,0 +1,18 @@
+from typing import Literal, Tuple
+
+import torch
+
+
+def get_best_device() -> Tuple[Literal["cpu", "cuda", "mps"], str]:
+    """Determine the best available device to run nerfstudio inference.
+
+    Returns:
+        tuple: (device_type, reason) where device_type is the selected device and
+        reason is an explanation of why it was chosen
+    """
+    if torch.cuda.is_available():
+        return "cuda", "CUDA GPU available - using for optimal performance"
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        return "mps", "Apple Metal (MPS) available - using for accelerated performance"
+    else:
+        return "cpu", "No GPU/MPS detected - falling back to CPU"
