Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ns-install-cli #1396

Merged
merged 7 commits into from
Feb 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 193 additions & 78 deletions scripts/completions/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,29 @@
import stat
import subprocess
import sys
from typing import List, Union
from typing import List, Optional, Union

import tyro
from rich.console import Console
from rich.prompt import Confirm
from typing_extensions import Literal, assert_never
from typing_extensions import get_args as typing_get_args

if sys.version_info < (3, 10):
import importlib_metadata
else:
from importlib import metadata as importlib_metadata

ConfigureMode = Literal["install", "uninstall"]
ShellType = Literal["zsh", "bash"]

CONSOLE = Console(width=120)
HEADER_LINE = "# Source nerfstudio autocompletions."


ENTRYPOINTS = [
"ns-install-cli",
"ns-process-data",
"ns-download-data",
"ns-train",
"ns-eval",
"ns-render",
"ns-dev-test",
]
def _get_all_entry_points() -> List[str]:
entry_points = importlib_metadata.distribution("nerfstudio").entry_points
return [x.name for x in entry_points]


def _check_tyro_cli(script_path: pathlib.Path) -> bool:
Expand Down Expand Up @@ -127,26 +129,41 @@ def _exclamation() -> str:
return random.choice(["Cool", "Nice", "Neat", "Great", "Exciting", "Excellent", "Ok"]) + "!"


def _update_rc(
completions_dir: pathlib.Path,
mode: ConfigureMode,
shell: ShellType,
) -> None:
"""Try to add a `source /.../completions/setup.{shell}` line automatically to a user's zshrc or bashrc.
def _get_deactivate_script(commands: List[str], shell: Optional[ShellType], add_header=True) -> str:
if shell is None:
# Install the universal script
result_script = []
for shell_type in typing_get_args(ShellType):
result_script.append(f'if [ -n "${shell_type.upper()}_VERSION" ]; then')
result_script.append(_get_deactivate_script(commands, shell_type, add_header=False))
result_script.append("fi")
source_lines = "\n".join(result_script)

Args:
completions_dir: Path to location of this script.
shell: Shell to install completion scripts for.
mode: Install or uninstall completions.
"""
elif shell == "zsh":
source_lines = "\n".join([f"unset '_comps[{command}]' &> /dev/null" for command in commands])
elif shell == "bash":
source_lines = "\n".join([f"complete -r {command} &> /dev/null" for command in commands])
else:
assert_never(shell)

# Install or uninstall `source_line`.
header_line = "# Source nerfstudio autocompletions."
if shell == "zsh":
if add_header:
source_lines = f"\n{HEADER_LINE}\n{source_lines}"
return source_lines


def _get_source_script(completions_dir: pathlib.Path, shell: Optional[ShellType], add_header=True) -> str:
if shell is None:
# Install the universal script
result_script = []
for shell_type in typing_get_args(ShellType):
result_script.append(f'if [ -n "${shell_type.upper()}_VERSION" ]; then')
result_script.append(_get_source_script(completions_dir, shell_type, add_header=False))
result_script.append("fi")
source_lines = "\n".join(result_script)

elif shell == "zsh":
source_lines = "\n".join(
[
"",
header_line,
"if ! command -v compdef &> /dev/null; then",
" autoload -Uz compinit",
" compinit",
Expand All @@ -157,20 +174,38 @@ def _update_rc(
elif shell == "bash":
source_lines = "\n".join(
[
"",
header_line,
f"source {completions_dir / 'setup.bash'}",
]
)
else:
assert_never(shell)

if add_header:
source_lines = f"\n{HEADER_LINE}\n{source_lines}"
return source_lines


def _update_rc(
completions_dir: pathlib.Path,
mode: ConfigureMode,
shell: ShellType,
) -> None:
"""Try to add a `source /.../completions/setup.{shell}` line automatically to a user's zshrc or bashrc.

Args:
completions_dir: Path to location of this script.
shell: Shell to install completion scripts for.
mode: Install or uninstall completions.
"""

# Install or uninstall `source_line`.
source_lines = _get_source_script(completions_dir, shell)
rc_path = pathlib.Path(os.environ["HOME"]) / f".{shell}rc"

# Always try to uninstall previous completions.
rc_source = rc_path.read_text()
while header_line in rc_source:
before_install, _, after_install = rc_source.partition(header_line)
while HEADER_LINE in rc_source:
before_install, _, after_install = rc_source.partition(HEADER_LINE)
source_file, _, after_install = after_install.partition("\nsource ")[2].partition("\n")
assert source_file.endswith(f"/completions/setup.{shell}")
rc_source = before_install + after_install
Expand All @@ -189,6 +224,108 @@ def _update_rc(
assert mode == "uninstall"


def _update_conda_scripts(
commands: List[str],
completions_dir: pathlib.Path,
mode: ConfigureMode,
) -> None:
"""Try to add a `source /.../completions/setup.{shell}` line automatically to conda's activation scripts.

Args:
completions_dir: Path to location of this script.
mode: Install or uninstall completions.
"""

# Install or uninstall `source_line`.
activate_source_lines = _get_source_script(completions_dir, None)
deactivate_source_lines = _get_deactivate_script(commands, None)

conda_path = pathlib.Path(os.environ["CONDA_PREFIX"])
activate_path = conda_path / "etc/conda/activate.d/nerfstudio_activate.sh"
deactivate_path = conda_path / "etc/conda/deactivate.d/nerfstudio_deactivate.sh"
if mode == "uninstall":
if activate_path.exists():
os.remove(activate_path)
if deactivate_path.exists():
os.remove(deactivate_path)
CONSOLE.log(f":broom: Existing completions uninstalled from {conda_path}.")
elif mode == "install":
# Install completions.
activate_path.parent.mkdir(exist_ok=True, parents=True)
brentyi marked this conversation as resolved.
Show resolved Hide resolved
deactivate_path.parent.mkdir(exist_ok=True, parents=True)
with activate_path.open("w+", encoding="utf8") as f:
brentyi marked this conversation as resolved.
Show resolved Hide resolved
f.write(activate_source_lines)
with deactivate_path.open("w+", encoding="utf8") as f:
f.write(deactivate_source_lines)
CONSOLE.log(
f":person_gesturing_ok: Completions installed to {conda_path}. {_exclamation()} Reactivate the environment"
" to try them out."
)
else:
assert_never(mode)


def _get_conda_path() -> Optional[pathlib.Path]:
"""
Returns the path to the conda environment if
the nerfstudio package is installed in one.
"""
conda_path = None
if "CONDA_PREFIX" in os.environ:
# Conda is active, we will check if the Nerfstudio is installed in the conda env.
distribution = importlib_metadata.distribution("nerfstudio")
if str(distribution.locate_file("nerfstudio")).startswith(os.environ["CONDA_PREFIX"]):
conda_path = pathlib.Path(os.environ["CONDA_PREFIX"])
return conda_path


def _generate_completions_files(
completions_dir: pathlib.Path,
scripts_dir: pathlib.Path,
shells_supported: List[ShellType],
shells_found: List[ShellType],
) -> None:
# Set to True to install completions for scripts as well.
include_scripts = False

# Find tyro CLIs.
script_paths = list(filter(_check_tyro_cli, scripts_dir.glob("**/*.py"))) if include_scripts else []
script_names = tuple(p.name for p in script_paths)
assert len(set(script_names)) == len(script_names)

# Get existing completion files.
existing_completions = set()
for shell in shells_supported:
target_dir = completions_dir / shell
if target_dir.exists():
existing_completions |= set(target_dir.glob("*"))

# Get all entry_points.
entry_points = _get_all_entry_points()

# Run generation jobs.
concurrent_executor = concurrent.futures.ThreadPoolExecutor()
with CONSOLE.status("[bold]:writing_hand: Generating completions...", spinner="bouncingBall"):
completion_paths = list(
concurrent_executor.map(
lambda path_or_entrypoint_and_shell: _generate_completion(
path_or_entrypoint_and_shell[0], path_or_entrypoint_and_shell[1], completions_dir
),
itertools.product(script_paths + entry_points, shells_found),
)
)

# Delete obsolete completion files.
for unexpected_path in set(p.absolute() for p in existing_completions) - set(
p.absolute() for p in completion_paths
):
if unexpected_path.is_dir():
shutil.rmtree(unexpected_path)
elif unexpected_path.exists():
unexpected_path.unlink()
CONSOLE.log(f":broom: Deleted {unexpected_path}.")


def main(mode: ConfigureMode = "install") -> None:
"""Main script.

Expand All @@ -201,16 +338,24 @@ def main(mode: ConfigureMode = "install") -> None:
CONSOLE.log("[bold red]$HOME is not set. Exiting.")
return

# Get conda path if in conda environment.
conda_path = _get_conda_path()

# Try to locate the user's bashrc or zshrc.
shells_supported: List[ShellType] = ["zsh", "bash"]
shells_found: List[ShellType] = []
for shell in shells_supported:
rc_path = pathlib.Path(os.environ["HOME"]) / f".{shell}rc"
if not rc_path.exists():
CONSOLE.log(f":person_shrugging: {rc_path.name} not found, skipping.")
else:
CONSOLE.log(f":mag: Found {rc_path.name}!")
shells_found.append(shell)
shells_supported: List[ShellType] = list(typing_get_args(ShellType))
if conda_path is not None:
# Running in conda; we have to support all shells.
shells_found = shells_supported
CONSOLE.log(f":mag: Detected conda environment {conda_path}!")
else:
shells_found: List[ShellType] = []
for shell in shells_supported:
rc_path = pathlib.Path(os.environ["HOME"]) / f".{shell}rc"
if not rc_path.exists():
CONSOLE.log(f":person_shrugging: {rc_path.name} not found, skipping.")
else:
CONSOLE.log(f":mag: Found {rc_path.name}!")
shells_found.append(shell)

# Get scripts/ directory.
completions_dir = pathlib.Path(__file__).absolute().parent
Expand All @@ -230,48 +375,18 @@ def main(mode: ConfigureMode = "install") -> None:
else:
CONSOLE.log(f":heavy_check_mark: No existing completions at: {target_dir}.")
elif mode == "install":
# Set to True to install completions for scripts as well.
include_scripts = False

# Find tyro CLIs.
script_paths = list(filter(_check_tyro_cli, scripts_dir.glob("**/*.py"))) if include_scripts else []
script_names = tuple(p.name for p in script_paths)
assert len(set(script_names)) == len(script_names)

# Get existing completion files.
existing_completions = set()
for shell in shells_supported:
target_dir = completions_dir / shell
if target_dir.exists():
existing_completions |= set(target_dir.glob("*"))

# Run generation jobs.
concurrent_executor = concurrent.futures.ThreadPoolExecutor()
with CONSOLE.status("[bold]:writing_hand: Generating completions...", spinner="bouncingBall"):
completion_paths = list(
concurrent_executor.map(
lambda path_or_entrypoint_and_shell: _generate_completion(
path_or_entrypoint_and_shell[0], path_or_entrypoint_and_shell[1], completions_dir
),
itertools.product(script_paths + ENTRYPOINTS, shells_found),
)
)

# Delete obsolete completion files.
for unexpected_path in set(p.absolute() for p in existing_completions) - set(
p.absolute() for p in completion_paths
):
if unexpected_path.is_dir():
shutil.rmtree(unexpected_path)
elif unexpected_path.exists():
unexpected_path.unlink()
CONSOLE.log(f":broom: Deleted {unexpected_path}.")
_generate_completions_files(completions_dir, scripts_dir, shells_supported, shells_found)
else:
assert_never(mode)

# Install or uninstall from bashrc/zshrc.
for shell in shells_found:
_update_rc(completions_dir, mode, shell)
if conda_path is not None:
# In conda environment we add the completitions activation scripts.
commands = _get_all_entry_points()
_update_conda_scripts(commands, completions_dir, mode)
else:
# Install or uninstall from bashrc/zshrc.
for shell in shells_found:
_update_rc(completions_dir, mode, shell)

CONSOLE.print("[bold]All done![/bold]")

Expand Down