Skip to content

Commit 97498e3

Browse files
authored
Improve ns-install-cli (#1396)
* Improve ns-install-cli * Remove unused code * ns-install-cli conda specify encoding explicitly * Fix missing directory in ns-install-cli * Implement ns-install-cli deactivate for conda env * ns-install-cli: conda create deactivate.d
1 parent 78eeb3b commit 97498e3

File tree

1 file changed

+193
-78
lines changed

1 file changed

+193
-78
lines changed

scripts/completions/install.py

+193-78
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,29 @@
1010
import stat
1111
import subprocess
1212
import sys
13-
from typing import List, Union
13+
from typing import List, Optional, Union
1414

1515
import tyro
1616
from rich.console import Console
1717
from rich.prompt import Confirm
1818
from typing_extensions import Literal, assert_never
19+
from typing_extensions import get_args as typing_get_args
20+
21+
if sys.version_info < (3, 10):
22+
import importlib_metadata
23+
else:
24+
from importlib import metadata as importlib_metadata
1925

2026
ConfigureMode = Literal["install", "uninstall"]
2127
ShellType = Literal["zsh", "bash"]
2228

2329
CONSOLE = Console(width=120)
30+
HEADER_LINE = "# Source nerfstudio autocompletions."
31+
2432

25-
ENTRYPOINTS = [
26-
"ns-install-cli",
27-
"ns-process-data",
28-
"ns-download-data",
29-
"ns-train",
30-
"ns-eval",
31-
"ns-render",
32-
"ns-dev-test",
33-
]
33+
def _get_all_entry_points() -> List[str]:
34+
entry_points = importlib_metadata.distribution("nerfstudio").entry_points
35+
return [x.name for x in entry_points]
3436

3537

3638
def _check_tyro_cli(script_path: pathlib.Path) -> bool:
@@ -127,26 +129,41 @@ def _exclamation() -> str:
127129
return random.choice(["Cool", "Nice", "Neat", "Great", "Exciting", "Excellent", "Ok"]) + "!"
128130

129131

130-
def _update_rc(
131-
completions_dir: pathlib.Path,
132-
mode: ConfigureMode,
133-
shell: ShellType,
134-
) -> None:
135-
"""Try to add a `source /.../completions/setup.{shell}` line automatically to a user's zshrc or bashrc.
132+
def _get_deactivate_script(commands: List[str], shell: Optional[ShellType], add_header=True) -> str:
133+
if shell is None:
134+
# Install the universal script
135+
result_script = []
136+
for shell_type in typing_get_args(ShellType):
137+
result_script.append(f'if [ -n "${shell_type.upper()}_VERSION" ]; then')
138+
result_script.append(_get_deactivate_script(commands, shell_type, add_header=False))
139+
result_script.append("fi")
140+
source_lines = "\n".join(result_script)
136141

137-
Args:
138-
completions_dir: Path to location of this script.
139-
shell: Shell to install completion scripts for.
140-
mode: Install or uninstall completions.
141-
"""
142+
elif shell == "zsh":
143+
source_lines = "\n".join([f"unset '_comps[{command}]' &> /dev/null" for command in commands])
144+
elif shell == "bash":
145+
source_lines = "\n".join([f"complete -r {command} &> /dev/null" for command in commands])
146+
else:
147+
assert_never(shell)
142148

143-
# Install or uninstall `source_line`.
144-
header_line = "# Source nerfstudio autocompletions."
145-
if shell == "zsh":
149+
if add_header:
150+
source_lines = f"\n{HEADER_LINE}\n{source_lines}"
151+
return source_lines
152+
153+
154+
def _get_source_script(completions_dir: pathlib.Path, shell: Optional[ShellType], add_header=True) -> str:
155+
if shell is None:
156+
# Install the universal script
157+
result_script = []
158+
for shell_type in typing_get_args(ShellType):
159+
result_script.append(f'if [ -n "${shell_type.upper()}_VERSION" ]; then')
160+
result_script.append(_get_source_script(completions_dir, shell_type, add_header=False))
161+
result_script.append("fi")
162+
source_lines = "\n".join(result_script)
163+
164+
elif shell == "zsh":
146165
source_lines = "\n".join(
147166
[
148-
"",
149-
header_line,
150167
"if ! command -v compdef &> /dev/null; then",
151168
" autoload -Uz compinit",
152169
" compinit",
@@ -157,20 +174,38 @@ def _update_rc(
157174
elif shell == "bash":
158175
source_lines = "\n".join(
159176
[
160-
"",
161-
header_line,
162177
f"source {completions_dir / 'setup.bash'}",
163178
]
164179
)
165180
else:
166181
assert_never(shell)
167182

183+
if add_header:
184+
source_lines = f"\n{HEADER_LINE}\n{source_lines}"
185+
return source_lines
186+
187+
188+
def _update_rc(
189+
completions_dir: pathlib.Path,
190+
mode: ConfigureMode,
191+
shell: ShellType,
192+
) -> None:
193+
"""Try to add a `source /.../completions/setup.{shell}` line automatically to a user's zshrc or bashrc.
194+
195+
Args:
196+
completions_dir: Path to location of this script.
197+
shell: Shell to install completion scripts for.
198+
mode: Install or uninstall completions.
199+
"""
200+
201+
# Install or uninstall `source_line`.
202+
source_lines = _get_source_script(completions_dir, shell)
168203
rc_path = pathlib.Path(os.environ["HOME"]) / f".{shell}rc"
169204

170205
# Always try to uninstall previous completions.
171206
rc_source = rc_path.read_text()
172-
while header_line in rc_source:
173-
before_install, _, after_install = rc_source.partition(header_line)
207+
while HEADER_LINE in rc_source:
208+
before_install, _, after_install = rc_source.partition(HEADER_LINE)
174209
source_file, _, after_install = after_install.partition("\nsource ")[2].partition("\n")
175210
assert source_file.endswith(f"/completions/setup.{shell}")
176211
rc_source = before_install + after_install
@@ -189,6 +224,108 @@ def _update_rc(
189224
assert mode == "uninstall"
190225

191226

227+
def _update_conda_scripts(
228+
commands: List[str],
229+
completions_dir: pathlib.Path,
230+
mode: ConfigureMode,
231+
) -> None:
232+
"""Try to add a `source /.../completions/setup.{shell}` line automatically to conda's activation scripts.
233+
234+
Args:
235+
completions_dir: Path to location of this script.
236+
mode: Install or uninstall completions.
237+
"""
238+
239+
# Install or uninstall `source_line`.
240+
activate_source_lines = _get_source_script(completions_dir, None)
241+
deactivate_source_lines = _get_deactivate_script(commands, None)
242+
243+
conda_path = pathlib.Path(os.environ["CONDA_PREFIX"])
244+
activate_path = conda_path / "etc/conda/activate.d/nerfstudio_activate.sh"
245+
deactivate_path = conda_path / "etc/conda/deactivate.d/nerfstudio_deactivate.sh"
246+
if mode == "uninstall":
247+
if activate_path.exists():
248+
os.remove(activate_path)
249+
if deactivate_path.exists():
250+
os.remove(deactivate_path)
251+
CONSOLE.log(f":broom: Existing completions uninstalled from {conda_path}.")
252+
elif mode == "install":
253+
# Install completions.
254+
activate_path.parent.mkdir(exist_ok=True, parents=True)
255+
deactivate_path.parent.mkdir(exist_ok=True, parents=True)
256+
with activate_path.open("w+", encoding="utf8") as f:
257+
f.write(activate_source_lines)
258+
with deactivate_path.open("w+", encoding="utf8") as f:
259+
f.write(deactivate_source_lines)
260+
CONSOLE.log(
261+
f":person_gesturing_ok: Completions installed to {conda_path}. {_exclamation()} Reactivate the environment"
262+
" to try them out."
263+
)
264+
else:
265+
assert_never(mode)
266+
267+
268+
def _get_conda_path() -> Optional[pathlib.Path]:
269+
"""
270+
Returns the path to the conda environment if
271+
the nerfstudio package is installed in one.
272+
"""
273+
conda_path = None
274+
if "CONDA_PREFIX" in os.environ:
275+
# Conda is active, we will check if the Nerfstudio is installed in the conda env.
276+
distribution = importlib_metadata.distribution("nerfstudio")
277+
if str(distribution.locate_file("nerfstudio")).startswith(os.environ["CONDA_PREFIX"]):
278+
conda_path = pathlib.Path(os.environ["CONDA_PREFIX"])
279+
return conda_path
280+
281+
282+
def _generate_completions_files(
283+
completions_dir: pathlib.Path,
284+
scripts_dir: pathlib.Path,
285+
shells_supported: List[ShellType],
286+
shells_found: List[ShellType],
287+
) -> None:
288+
# Set to True to install completions for scripts as well.
289+
include_scripts = False
290+
291+
# Find tyro CLIs.
292+
script_paths = list(filter(_check_tyro_cli, scripts_dir.glob("**/*.py"))) if include_scripts else []
293+
script_names = tuple(p.name for p in script_paths)
294+
assert len(set(script_names)) == len(script_names)
295+
296+
# Get existing completion files.
297+
existing_completions = set()
298+
for shell in shells_supported:
299+
target_dir = completions_dir / shell
300+
if target_dir.exists():
301+
existing_completions |= set(target_dir.glob("*"))
302+
303+
# Get all entry_points.
304+
entry_points = _get_all_entry_points()
305+
306+
# Run generation jobs.
307+
concurrent_executor = concurrent.futures.ThreadPoolExecutor()
308+
with CONSOLE.status("[bold]:writing_hand: Generating completions...", spinner="bouncingBall"):
309+
completion_paths = list(
310+
concurrent_executor.map(
311+
lambda path_or_entrypoint_and_shell: _generate_completion(
312+
path_or_entrypoint_and_shell[0], path_or_entrypoint_and_shell[1], completions_dir
313+
),
314+
itertools.product(script_paths + entry_points, shells_found),
315+
)
316+
)
317+
318+
# Delete obsolete completion files.
319+
for unexpected_path in set(p.absolute() for p in existing_completions) - set(
320+
p.absolute() for p in completion_paths
321+
):
322+
if unexpected_path.is_dir():
323+
shutil.rmtree(unexpected_path)
324+
elif unexpected_path.exists():
325+
unexpected_path.unlink()
326+
CONSOLE.log(f":broom: Deleted {unexpected_path}.")
327+
328+
192329
def main(mode: ConfigureMode = "install") -> None:
193330
"""Main script.
194331
@@ -201,16 +338,24 @@ def main(mode: ConfigureMode = "install") -> None:
201338
CONSOLE.log("[bold red]$HOME is not set. Exiting.")
202339
return
203340

341+
# Get conda path if in conda environment.
342+
conda_path = _get_conda_path()
343+
204344
# Try to locate the user's bashrc or zshrc.
205-
shells_supported: List[ShellType] = ["zsh", "bash"]
206-
shells_found: List[ShellType] = []
207-
for shell in shells_supported:
208-
rc_path = pathlib.Path(os.environ["HOME"]) / f".{shell}rc"
209-
if not rc_path.exists():
210-
CONSOLE.log(f":person_shrugging: {rc_path.name} not found, skipping.")
211-
else:
212-
CONSOLE.log(f":mag: Found {rc_path.name}!")
213-
shells_found.append(shell)
345+
shells_supported: List[ShellType] = list(typing_get_args(ShellType))
346+
if conda_path is not None:
347+
# Running in conda; we have to support all shells.
348+
shells_found = shells_supported
349+
CONSOLE.log(f":mag: Detected conda environment {conda_path}!")
350+
else:
351+
shells_found: List[ShellType] = []
352+
for shell in shells_supported:
353+
rc_path = pathlib.Path(os.environ["HOME"]) / f".{shell}rc"
354+
if not rc_path.exists():
355+
CONSOLE.log(f":person_shrugging: {rc_path.name} not found, skipping.")
356+
else:
357+
CONSOLE.log(f":mag: Found {rc_path.name}!")
358+
shells_found.append(shell)
214359

215360
# Get scripts/ directory.
216361
completions_dir = pathlib.Path(__file__).absolute().parent
@@ -230,48 +375,18 @@ def main(mode: ConfigureMode = "install") -> None:
230375
else:
231376
CONSOLE.log(f":heavy_check_mark: No existing completions at: {target_dir}.")
232377
elif mode == "install":
233-
# Set to True to install completions for scripts as well.
234-
include_scripts = False
235-
236-
# Find tyro CLIs.
237-
script_paths = list(filter(_check_tyro_cli, scripts_dir.glob("**/*.py"))) if include_scripts else []
238-
script_names = tuple(p.name for p in script_paths)
239-
assert len(set(script_names)) == len(script_names)
240-
241-
# Get existing completion files.
242-
existing_completions = set()
243-
for shell in shells_supported:
244-
target_dir = completions_dir / shell
245-
if target_dir.exists():
246-
existing_completions |= set(target_dir.glob("*"))
247-
248-
# Run generation jobs.
249-
concurrent_executor = concurrent.futures.ThreadPoolExecutor()
250-
with CONSOLE.status("[bold]:writing_hand: Generating completions...", spinner="bouncingBall"):
251-
completion_paths = list(
252-
concurrent_executor.map(
253-
lambda path_or_entrypoint_and_shell: _generate_completion(
254-
path_or_entrypoint_and_shell[0], path_or_entrypoint_and_shell[1], completions_dir
255-
),
256-
itertools.product(script_paths + ENTRYPOINTS, shells_found),
257-
)
258-
)
259-
260-
# Delete obsolete completion files.
261-
for unexpected_path in set(p.absolute() for p in existing_completions) - set(
262-
p.absolute() for p in completion_paths
263-
):
264-
if unexpected_path.is_dir():
265-
shutil.rmtree(unexpected_path)
266-
elif unexpected_path.exists():
267-
unexpected_path.unlink()
268-
CONSOLE.log(f":broom: Deleted {unexpected_path}.")
378+
_generate_completions_files(completions_dir, scripts_dir, shells_supported, shells_found)
269379
else:
270380
assert_never(mode)
271381

272-
# Install or uninstall from bashrc/zshrc.
273-
for shell in shells_found:
274-
_update_rc(completions_dir, mode, shell)
382+
if conda_path is not None:
383+
# In conda environment we add the completitions activation scripts.
384+
commands = _get_all_entry_points()
385+
_update_conda_scripts(commands, completions_dir, mode)
386+
else:
387+
# Install or uninstall from bashrc/zshrc.
388+
for shell in shells_found:
389+
_update_rc(completions_dir, mode, shell)
275390

276391
CONSOLE.print("[bold]All done![/bold]")
277392

0 commit comments

Comments
 (0)