diff --git a/scripts/completions/install.py b/scripts/completions/install.py index 636ed19798..98e67fea27 100755 --- a/scripts/completions/install.py +++ b/scripts/completions/install.py @@ -10,27 +10,29 @@ import stat import subprocess import sys -from typing import List, Union +from typing import List, Optional, Union import tyro from rich.console import Console from rich.prompt import Confirm from typing_extensions import Literal, assert_never +from typing_extensions import get_args as typing_get_args + +if sys.version_info < (3, 10): + import importlib_metadata +else: + from importlib import metadata as importlib_metadata ConfigureMode = Literal["install", "uninstall"] ShellType = Literal["zsh", "bash"] CONSOLE = Console(width=120) +HEADER_LINE = "# Source nerfstudio autocompletions." + -ENTRYPOINTS = [ - "ns-install-cli", - "ns-process-data", - "ns-download-data", - "ns-train", - "ns-eval", - "ns-render", - "ns-dev-test", -] +def _get_all_entry_points() -> List[str]: + entry_points = importlib_metadata.distribution("nerfstudio").entry_points + return [x.name for x in entry_points] def _check_tyro_cli(script_path: pathlib.Path) -> bool: @@ -127,26 +129,41 @@ def _exclamation() -> str: return random.choice(["Cool", "Nice", "Neat", "Great", "Exciting", "Excellent", "Ok"]) + "!" -def _update_rc( - completions_dir: pathlib.Path, - mode: ConfigureMode, - shell: ShellType, -) -> None: - """Try to add a `source /.../completions/setup.{shell}` line automatically to a user's zshrc or bashrc. +def _get_deactivate_script(commands: List[str], shell: Optional[ShellType], add_header=True) -> str: + if shell is None: + # Install the universal script + result_script = [] + for shell_type in typing_get_args(ShellType): + result_script.append(f'if [ -n "${shell_type.upper()}_VERSION" ]; then') + result_script.append(_get_deactivate_script(commands, shell_type, add_header=False)) + result_script.append("fi") + source_lines = "\n".join(result_script) - Args: - completions_dir: Path to location of this script. - shell: Shell to install completion scripts for. - mode: Install or uninstall completions. - """ + elif shell == "zsh": + source_lines = "\n".join([f"unset '_comps[{command}]' &> /dev/null" for command in commands]) + elif shell == "bash": + source_lines = "\n".join([f"complete -r {command} &> /dev/null" for command in commands]) + else: + assert_never(shell) - # Install or uninstall `source_line`. - header_line = "# Source nerfstudio autocompletions." - if shell == "zsh": + if add_header: + source_lines = f"\n{HEADER_LINE}\n{source_lines}" + return source_lines + + +def _get_source_script(completions_dir: pathlib.Path, shell: Optional[ShellType], add_header=True) -> str: + if shell is None: + # Install the universal script + result_script = [] + for shell_type in typing_get_args(ShellType): + result_script.append(f'if [ -n "${shell_type.upper()}_VERSION" ]; then') + result_script.append(_get_source_script(completions_dir, shell_type, add_header=False)) + result_script.append("fi") + source_lines = "\n".join(result_script) + + elif shell == "zsh": source_lines = "\n".join( [ - "", - header_line, "if ! command -v compdef &> /dev/null; then", " autoload -Uz compinit", " compinit", @@ -157,20 +174,38 @@ def _update_rc( elif shell == "bash": source_lines = "\n".join( [ - "", - header_line, f"source {completions_dir / 'setup.bash'}", ] ) else: assert_never(shell) + if add_header: + source_lines = f"\n{HEADER_LINE}\n{source_lines}" + return source_lines + + +def _update_rc( + completions_dir: pathlib.Path, + mode: ConfigureMode, + shell: ShellType, +) -> None: + """Try to add a `source /.../completions/setup.{shell}` line automatically to a user's zshrc or bashrc. + + Args: + completions_dir: Path to location of this script. + shell: Shell to install completion scripts for. + mode: Install or uninstall completions. + """ + + # Install or uninstall `source_line`. + source_lines = _get_source_script(completions_dir, shell) rc_path = pathlib.Path(os.environ["HOME"]) / f".{shell}rc" # Always try to uninstall previous completions. rc_source = rc_path.read_text() - while header_line in rc_source: - before_install, _, after_install = rc_source.partition(header_line) + while HEADER_LINE in rc_source: + before_install, _, after_install = rc_source.partition(HEADER_LINE) source_file, _, after_install = after_install.partition("\nsource ")[2].partition("\n") assert source_file.endswith(f"/completions/setup.{shell}") rc_source = before_install + after_install @@ -189,6 +224,108 @@ def _update_rc( assert mode == "uninstall" +def _update_conda_scripts( + commands: List[str], + completions_dir: pathlib.Path, + mode: ConfigureMode, +) -> None: + """Try to add a `source /.../completions/setup.{shell}` line automatically to conda's activation scripts. + + Args: + completions_dir: Path to location of this script. + mode: Install or uninstall completions. + """ + + # Install or uninstall `source_line`. + activate_source_lines = _get_source_script(completions_dir, None) + deactivate_source_lines = _get_deactivate_script(commands, None) + + conda_path = pathlib.Path(os.environ["CONDA_PREFIX"]) + activate_path = conda_path / "etc/conda/activate.d/nerfstudio_activate.sh" + deactivate_path = conda_path / "etc/conda/deactivate.d/nerfstudio_deactivate.sh" + if mode == "uninstall": + if activate_path.exists(): + os.remove(activate_path) + if deactivate_path.exists(): + os.remove(deactivate_path) + CONSOLE.log(f":broom: Existing completions uninstalled from {conda_path}.") + elif mode == "install": + # Install completions. + activate_path.parent.mkdir(exist_ok=True, parents=True) + deactivate_path.parent.mkdir(exist_ok=True, parents=True) + with activate_path.open("w+", encoding="utf8") as f: + f.write(activate_source_lines) + with deactivate_path.open("w+", encoding="utf8") as f: + f.write(deactivate_source_lines) + CONSOLE.log( + f":person_gesturing_ok: Completions installed to {conda_path}. {_exclamation()} Reactivate the environment" + " to try them out." + ) + else: + assert_never(mode) + + +def _get_conda_path() -> Optional[pathlib.Path]: + """ + Returns the path to the conda environment if + the nerfstudio package is installed in one. + """ + conda_path = None + if "CONDA_PREFIX" in os.environ: + # Conda is active, we will check if the Nerfstudio is installed in the conda env. + distribution = importlib_metadata.distribution("nerfstudio") + if str(distribution.locate_file("nerfstudio")).startswith(os.environ["CONDA_PREFIX"]): + conda_path = pathlib.Path(os.environ["CONDA_PREFIX"]) + return conda_path + + +def _generate_completions_files( + completions_dir: pathlib.Path, + scripts_dir: pathlib.Path, + shells_supported: List[ShellType], + shells_found: List[ShellType], +) -> None: + # Set to True to install completions for scripts as well. + include_scripts = False + + # Find tyro CLIs. + script_paths = list(filter(_check_tyro_cli, scripts_dir.glob("**/*.py"))) if include_scripts else [] + script_names = tuple(p.name for p in script_paths) + assert len(set(script_names)) == len(script_names) + + # Get existing completion files. + existing_completions = set() + for shell in shells_supported: + target_dir = completions_dir / shell + if target_dir.exists(): + existing_completions |= set(target_dir.glob("*")) + + # Get all entry_points. + entry_points = _get_all_entry_points() + + # Run generation jobs. + concurrent_executor = concurrent.futures.ThreadPoolExecutor() + with CONSOLE.status("[bold]:writing_hand: Generating completions...", spinner="bouncingBall"): + completion_paths = list( + concurrent_executor.map( + lambda path_or_entrypoint_and_shell: _generate_completion( + path_or_entrypoint_and_shell[0], path_or_entrypoint_and_shell[1], completions_dir + ), + itertools.product(script_paths + entry_points, shells_found), + ) + ) + + # Delete obsolete completion files. + for unexpected_path in set(p.absolute() for p in existing_completions) - set( + p.absolute() for p in completion_paths + ): + if unexpected_path.is_dir(): + shutil.rmtree(unexpected_path) + elif unexpected_path.exists(): + unexpected_path.unlink() + CONSOLE.log(f":broom: Deleted {unexpected_path}.") + + def main(mode: ConfigureMode = "install") -> None: """Main script. @@ -201,16 +338,24 @@ def main(mode: ConfigureMode = "install") -> None: CONSOLE.log("[bold red]$HOME is not set. Exiting.") return + # Get conda path if in conda environment. + conda_path = _get_conda_path() + # Try to locate the user's bashrc or zshrc. - shells_supported: List[ShellType] = ["zsh", "bash"] - shells_found: List[ShellType] = [] - for shell in shells_supported: - rc_path = pathlib.Path(os.environ["HOME"]) / f".{shell}rc" - if not rc_path.exists(): - CONSOLE.log(f":person_shrugging: {rc_path.name} not found, skipping.") - else: - CONSOLE.log(f":mag: Found {rc_path.name}!") - shells_found.append(shell) + shells_supported: List[ShellType] = list(typing_get_args(ShellType)) + if conda_path is not None: + # Running in conda; we have to support all shells. + shells_found = shells_supported + CONSOLE.log(f":mag: Detected conda environment {conda_path}!") + else: + shells_found: List[ShellType] = [] + for shell in shells_supported: + rc_path = pathlib.Path(os.environ["HOME"]) / f".{shell}rc" + if not rc_path.exists(): + CONSOLE.log(f":person_shrugging: {rc_path.name} not found, skipping.") + else: + CONSOLE.log(f":mag: Found {rc_path.name}!") + shells_found.append(shell) # Get scripts/ directory. completions_dir = pathlib.Path(__file__).absolute().parent @@ -230,48 +375,18 @@ def main(mode: ConfigureMode = "install") -> None: else: CONSOLE.log(f":heavy_check_mark: No existing completions at: {target_dir}.") elif mode == "install": - # Set to True to install completions for scripts as well. - include_scripts = False - - # Find tyro CLIs. - script_paths = list(filter(_check_tyro_cli, scripts_dir.glob("**/*.py"))) if include_scripts else [] - script_names = tuple(p.name for p in script_paths) - assert len(set(script_names)) == len(script_names) - - # Get existing completion files. - existing_completions = set() - for shell in shells_supported: - target_dir = completions_dir / shell - if target_dir.exists(): - existing_completions |= set(target_dir.glob("*")) - - # Run generation jobs. - concurrent_executor = concurrent.futures.ThreadPoolExecutor() - with CONSOLE.status("[bold]:writing_hand: Generating completions...", spinner="bouncingBall"): - completion_paths = list( - concurrent_executor.map( - lambda path_or_entrypoint_and_shell: _generate_completion( - path_or_entrypoint_and_shell[0], path_or_entrypoint_and_shell[1], completions_dir - ), - itertools.product(script_paths + ENTRYPOINTS, shells_found), - ) - ) - - # Delete obsolete completion files. - for unexpected_path in set(p.absolute() for p in existing_completions) - set( - p.absolute() for p in completion_paths - ): - if unexpected_path.is_dir(): - shutil.rmtree(unexpected_path) - elif unexpected_path.exists(): - unexpected_path.unlink() - CONSOLE.log(f":broom: Deleted {unexpected_path}.") + _generate_completions_files(completions_dir, scripts_dir, shells_supported, shells_found) else: assert_never(mode) - # Install or uninstall from bashrc/zshrc. - for shell in shells_found: - _update_rc(completions_dir, mode, shell) + if conda_path is not None: + # In conda environment we add the completitions activation scripts. + commands = _get_all_entry_points() + _update_conda_scripts(commands, completions_dir, mode) + else: + # Install or uninstall from bashrc/zshrc. + for shell in shells_found: + _update_rc(completions_dir, mode, shell) CONSOLE.print("[bold]All done![/bold]")