Skip to content

Commit f9723d0

Browse files
brentyikerrj
authored andcommitted
Clean up ns-train {method} --help for not-yet-installed external methods (nerfstudio-project#2760)
* Clean up `ns-train {method} --help` for not-yet-installed external methods * Ruff * Ruff * add clearer print statement * Types? Not sure if this fixes it --------- Co-authored-by: Justin Kerr <[email protected]>
1 parent e62d02d commit f9723d0

File tree

2 files changed

+32
-27
lines changed

2 files changed

+32
-27
lines changed

nerfstudio/configs/external_methods.py

+29-24
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414

1515

1616
"""This file contains the configuration for external methods which are not included in this repository."""
17+
import inspect
1718
import subprocess
1819
import sys
19-
from dataclasses import dataclass, field
20-
from typing import Any, Dict, List, Optional, Tuple, cast
20+
from dataclasses import dataclass
21+
from typing import Dict, List, Optional, Tuple
2122

23+
import tyro
2224
from rich.prompt import Confirm
2325

24-
from nerfstudio.engine.trainer import TrainerConfig
2526
from nerfstudio.utils.rich_utils import CONSOLE
2627

2728

@@ -177,21 +178,30 @@ class ExternalMethod:
177178

178179

179180
@dataclass
180-
class ExternalMethodTrainerConfig(TrainerConfig):
181-
"""
182-
Trainer config for external methods which does not have an implementation in this repository.
181+
class ExternalMethodDummyTrainerConfig:
182+
"""Dummy trainer config for external methods (a) which do not have an
183+
implementation in this repository, and (b) are not yet installed. When this
184+
config is instantiated, we give the user the option to install the method.
183185
"""
184186

185-
_method: ExternalMethod = field(default=cast(ExternalMethod, None))
187+
# tyro.conf.Suppress will prevent these fields from appearing as CLI arguments.
188+
method_name: tyro.conf.Suppress[str]
189+
method: tyro.conf.Suppress[ExternalMethod]
190+
191+
def __post_init__(self):
192+
"""Offer to install an external method."""
186193

187-
def handle_print_information(self, *_args, **_kwargs):
188-
"""Prints the method information and exits."""
189-
CONSOLE.print(self._method.instructions)
190-
if self._method.pip_package and Confirm.ask(
194+
# Don't trigger install message from get_external_methods() below; only
195+
# if this dummy object is instantiated from the CLI.
196+
if inspect.stack()[2].function == "get_external_methods":
197+
return
198+
199+
CONSOLE.print(self.method.instructions)
200+
if self.method.pip_package and Confirm.ask(
191201
"\nWould you like to run the install it now?", default=False, console=CONSOLE
192202
):
193203
# Install the method
194-
install_command = f"{sys.executable} -m pip install {self._method.pip_package}"
204+
install_command = f"{sys.executable} -m pip install {self.method.pip_package}"
195205
CONSOLE.print(f"Running: [cyan]{install_command}[/cyan]")
196206
result = subprocess.run(install_command, shell=True, check=False)
197207
if result.returncode != 0:
@@ -200,20 +210,15 @@ def handle_print_information(self, *_args, **_kwargs):
200210

201211
sys.exit(0)
202212

203-
def __getattribute__(self, __name: str) -> Any:
204-
out = object.__getattribute__(self, __name)
205-
if callable(out) and __name not in {"handle_print_information"} and not __name.startswith("__"):
206-
# We exit early, displaying the message
207-
return self.handle_print_information
208-
return out
209-
210213

211-
def get_external_methods() -> Tuple[Dict[str, TrainerConfig], Dict[str, str]]:
214+
def get_external_methods() -> Tuple[Dict[str, ExternalMethodDummyTrainerConfig], Dict[str, str]]:
212215
"""Returns the external methods trainer configs and the descriptions."""
213-
method_configs = {}
214-
descriptions = {}
216+
method_configs: Dict[str, ExternalMethodDummyTrainerConfig] = {}
217+
descriptions: Dict[str, str] = {}
215218
for external_method in external_methods:
216219
for config_slug, config_description in external_method.configurations:
217-
method_configs[config_slug] = ExternalMethodTrainerConfig(method_name=config_slug, _method=external_method)
218-
descriptions[config_slug] = f"""[External] {config_description}"""
220+
method_configs[config_slug] = ExternalMethodDummyTrainerConfig(
221+
method_name=config_slug, method=external_method
222+
)
223+
descriptions[config_slug] = f"""[External, run 'ns-train {config_slug}' to install] {config_description}"""
219224
return method_configs, descriptions

nerfstudio/configs/method_configs.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
from __future__ import annotations
2020

2121
from collections import OrderedDict
22-
from typing import Dict
22+
from typing import Dict, Union
2323

2424
import tyro
2525

2626
from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
2727
from nerfstudio.configs.base_config import ViewerConfig
28-
from nerfstudio.configs.external_methods import get_external_methods
28+
from nerfstudio.configs.external_methods import ExternalMethodDummyTrainerConfig, get_external_methods
2929
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig
3030
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig
3131
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManagerConfig
@@ -66,7 +66,7 @@
6666
from nerfstudio.pipelines.dynamic_batch import DynamicBatchPipelineConfig
6767
from nerfstudio.plugins.registry import discover_methods
6868

69-
method_configs: Dict[str, TrainerConfig] = {}
69+
method_configs: Dict[str, Union[TrainerConfig, ExternalMethodDummyTrainerConfig]] = {}
7070
descriptions = {
7171
"nerfacto": "Recommended real-time model tuned for real captures. This model will be continually updated.",
7272
"depth-nerfacto": "Nerfacto with depth supervision.",

0 commit comments

Comments
 (0)