10
10
import stat
11
11
import subprocess
12
12
import sys
13
- from typing import List , Union
13
+ from typing import List , Optional , Union
14
14
15
15
import tyro
16
16
from rich .console import Console
17
17
from rich .prompt import Confirm
18
18
from typing_extensions import Literal , assert_never
19
+ from typing_extensions import get_args as typing_get_args
20
+
21
+ if sys .version_info < (3 , 10 ):
22
+ import importlib_metadata
23
+ else :
24
+ from importlib import metadata as importlib_metadata
19
25
20
26
ConfigureMode = Literal ["install" , "uninstall" ]
21
27
ShellType = Literal ["zsh" , "bash" ]
22
28
23
29
CONSOLE = Console (width = 120 )
30
+ HEADER_LINE = "# Source nerfstudio autocompletions."
31
+
24
32
25
- ENTRYPOINTS = [
26
- "ns-install-cli" ,
27
- "ns-process-data" ,
28
- "ns-download-data" ,
29
- "ns-train" ,
30
- "ns-eval" ,
31
- "ns-render" ,
32
- "ns-dev-test" ,
33
- ]
33
+ def _get_all_entry_points () -> List [str ]:
34
+ entry_points = importlib_metadata .distribution ("nerfstudio" ).entry_points
35
+ return [x .name for x in entry_points ]
34
36
35
37
36
38
def _check_tyro_cli (script_path : pathlib .Path ) -> bool :
@@ -127,26 +129,41 @@ def _exclamation() -> str:
127
129
return random .choice (["Cool" , "Nice" , "Neat" , "Great" , "Exciting" , "Excellent" , "Ok" ]) + "!"
128
130
129
131
130
- def _update_rc (
131
- completions_dir : pathlib .Path ,
132
- mode : ConfigureMode ,
133
- shell : ShellType ,
134
- ) -> None :
135
- """Try to add a `source /.../completions/setup.{shell}` line automatically to a user's zshrc or bashrc.
132
+ def _get_deactivate_script (commands : List [str ], shell : Optional [ShellType ], add_header = True ) -> str :
133
+ if shell is None :
134
+ # Install the universal script
135
+ result_script = []
136
+ for shell_type in typing_get_args (ShellType ):
137
+ result_script .append (f'if [ -n "${ shell_type .upper ()} _VERSION" ]; then' )
138
+ result_script .append (_get_deactivate_script (commands , shell_type , add_header = False ))
139
+ result_script .append ("fi" )
140
+ source_lines = "\n " .join (result_script )
136
141
137
- Args:
138
- completions_dir: Path to location of this script.
139
- shell: Shell to install completion scripts for.
140
- mode: Install or uninstall completions.
141
- """
142
+ elif shell == "zsh" :
143
+ source_lines = "\n " .join ([f"unset '_comps[{ command } ]' &> /dev/null" for command in commands ])
144
+ elif shell == "bash" :
145
+ source_lines = "\n " .join ([f"complete -r { command } &> /dev/null" for command in commands ])
146
+ else :
147
+ assert_never (shell )
142
148
143
- # Install or uninstall `source_line`.
144
- header_line = "# Source nerfstudio autocompletions."
145
- if shell == "zsh" :
149
+ if add_header :
150
+ source_lines = f"\n { HEADER_LINE } \n { source_lines } "
151
+ return source_lines
152
+
153
+
154
+ def _get_source_script (completions_dir : pathlib .Path , shell : Optional [ShellType ], add_header = True ) -> str :
155
+ if shell is None :
156
+ # Install the universal script
157
+ result_script = []
158
+ for shell_type in typing_get_args (ShellType ):
159
+ result_script .append (f'if [ -n "${ shell_type .upper ()} _VERSION" ]; then' )
160
+ result_script .append (_get_source_script (completions_dir , shell_type , add_header = False ))
161
+ result_script .append ("fi" )
162
+ source_lines = "\n " .join (result_script )
163
+
164
+ elif shell == "zsh" :
146
165
source_lines = "\n " .join (
147
166
[
148
- "" ,
149
- header_line ,
150
167
"if ! command -v compdef &> /dev/null; then" ,
151
168
" autoload -Uz compinit" ,
152
169
" compinit" ,
@@ -157,20 +174,38 @@ def _update_rc(
157
174
elif shell == "bash" :
158
175
source_lines = "\n " .join (
159
176
[
160
- "" ,
161
- header_line ,
162
177
f"source { completions_dir / 'setup.bash' } " ,
163
178
]
164
179
)
165
180
else :
166
181
assert_never (shell )
167
182
183
+ if add_header :
184
+ source_lines = f"\n { HEADER_LINE } \n { source_lines } "
185
+ return source_lines
186
+
187
+
188
+ def _update_rc (
189
+ completions_dir : pathlib .Path ,
190
+ mode : ConfigureMode ,
191
+ shell : ShellType ,
192
+ ) -> None :
193
+ """Try to add a `source /.../completions/setup.{shell}` line automatically to a user's zshrc or bashrc.
194
+
195
+ Args:
196
+ completions_dir: Path to location of this script.
197
+ shell: Shell to install completion scripts for.
198
+ mode: Install or uninstall completions.
199
+ """
200
+
201
+ # Install or uninstall `source_line`.
202
+ source_lines = _get_source_script (completions_dir , shell )
168
203
rc_path = pathlib .Path (os .environ ["HOME" ]) / f".{ shell } rc"
169
204
170
205
# Always try to uninstall previous completions.
171
206
rc_source = rc_path .read_text ()
172
- while header_line in rc_source :
173
- before_install , _ , after_install = rc_source .partition (header_line )
207
+ while HEADER_LINE in rc_source :
208
+ before_install , _ , after_install = rc_source .partition (HEADER_LINE )
174
209
source_file , _ , after_install = after_install .partition ("\n source " )[2 ].partition ("\n " )
175
210
assert source_file .endswith (f"/completions/setup.{ shell } " )
176
211
rc_source = before_install + after_install
@@ -189,6 +224,108 @@ def _update_rc(
189
224
assert mode == "uninstall"
190
225
191
226
227
+ def _update_conda_scripts (
228
+ commands : List [str ],
229
+ completions_dir : pathlib .Path ,
230
+ mode : ConfigureMode ,
231
+ ) -> None :
232
+ """Try to add a `source /.../completions/setup.{shell}` line automatically to conda's activation scripts.
233
+
234
+ Args:
235
+ completions_dir: Path to location of this script.
236
+ mode: Install or uninstall completions.
237
+ """
238
+
239
+ # Install or uninstall `source_line`.
240
+ activate_source_lines = _get_source_script (completions_dir , None )
241
+ deactivate_source_lines = _get_deactivate_script (commands , None )
242
+
243
+ conda_path = pathlib .Path (os .environ ["CONDA_PREFIX" ])
244
+ activate_path = conda_path / "etc/conda/activate.d/nerfstudio_activate.sh"
245
+ deactivate_path = conda_path / "etc/conda/deactivate.d/nerfstudio_deactivate.sh"
246
+ if mode == "uninstall" :
247
+ if activate_path .exists ():
248
+ os .remove (activate_path )
249
+ if deactivate_path .exists ():
250
+ os .remove (deactivate_path )
251
+ CONSOLE .log (f":broom: Existing completions uninstalled from { conda_path } ." )
252
+ elif mode == "install" :
253
+ # Install completions.
254
+ activate_path .parent .mkdir (exist_ok = True , parents = True )
255
+ deactivate_path .parent .mkdir (exist_ok = True , parents = True )
256
+ with activate_path .open ("w+" , encoding = "utf8" ) as f :
257
+ f .write (activate_source_lines )
258
+ with deactivate_path .open ("w+" , encoding = "utf8" ) as f :
259
+ f .write (deactivate_source_lines )
260
+ CONSOLE .log (
261
+ f":person_gesturing_ok: Completions installed to { conda_path } . { _exclamation ()} Reactivate the environment"
262
+ " to try them out."
263
+ )
264
+ else :
265
+ assert_never (mode )
266
+
267
+
268
+ def _get_conda_path () -> Optional [pathlib .Path ]:
269
+ """
270
+ Returns the path to the conda environment if
271
+ the nerfstudio package is installed in one.
272
+ """
273
+ conda_path = None
274
+ if "CONDA_PREFIX" in os .environ :
275
+ # Conda is active, we will check if the Nerfstudio is installed in the conda env.
276
+ distribution = importlib_metadata .distribution ("nerfstudio" )
277
+ if str (distribution .locate_file ("nerfstudio" )).startswith (os .environ ["CONDA_PREFIX" ]):
278
+ conda_path = pathlib .Path (os .environ ["CONDA_PREFIX" ])
279
+ return conda_path
280
+
281
+
282
+ def _generate_completions_files (
283
+ completions_dir : pathlib .Path ,
284
+ scripts_dir : pathlib .Path ,
285
+ shells_supported : List [ShellType ],
286
+ shells_found : List [ShellType ],
287
+ ) -> None :
288
+ # Set to True to install completions for scripts as well.
289
+ include_scripts = False
290
+
291
+ # Find tyro CLIs.
292
+ script_paths = list (filter (_check_tyro_cli , scripts_dir .glob ("**/*.py" ))) if include_scripts else []
293
+ script_names = tuple (p .name for p in script_paths )
294
+ assert len (set (script_names )) == len (script_names )
295
+
296
+ # Get existing completion files.
297
+ existing_completions = set ()
298
+ for shell in shells_supported :
299
+ target_dir = completions_dir / shell
300
+ if target_dir .exists ():
301
+ existing_completions |= set (target_dir .glob ("*" ))
302
+
303
+ # Get all entry_points.
304
+ entry_points = _get_all_entry_points ()
305
+
306
+ # Run generation jobs.
307
+ concurrent_executor = concurrent .futures .ThreadPoolExecutor ()
308
+ with CONSOLE .status ("[bold]:writing_hand: Generating completions..." , spinner = "bouncingBall" ):
309
+ completion_paths = list (
310
+ concurrent_executor .map (
311
+ lambda path_or_entrypoint_and_shell : _generate_completion (
312
+ path_or_entrypoint_and_shell [0 ], path_or_entrypoint_and_shell [1 ], completions_dir
313
+ ),
314
+ itertools .product (script_paths + entry_points , shells_found ),
315
+ )
316
+ )
317
+
318
+ # Delete obsolete completion files.
319
+ for unexpected_path in set (p .absolute () for p in existing_completions ) - set (
320
+ p .absolute () for p in completion_paths
321
+ ):
322
+ if unexpected_path .is_dir ():
323
+ shutil .rmtree (unexpected_path )
324
+ elif unexpected_path .exists ():
325
+ unexpected_path .unlink ()
326
+ CONSOLE .log (f":broom: Deleted { unexpected_path } ." )
327
+
328
+
192
329
def main (mode : ConfigureMode = "install" ) -> None :
193
330
"""Main script.
194
331
@@ -201,16 +338,24 @@ def main(mode: ConfigureMode = "install") -> None:
201
338
CONSOLE .log ("[bold red]$HOME is not set. Exiting." )
202
339
return
203
340
341
+ # Get conda path if in conda environment.
342
+ conda_path = _get_conda_path ()
343
+
204
344
# Try to locate the user's bashrc or zshrc.
205
- shells_supported : List [ShellType ] = ["zsh" , "bash" ]
206
- shells_found : List [ShellType ] = []
207
- for shell in shells_supported :
208
- rc_path = pathlib .Path (os .environ ["HOME" ]) / f".{ shell } rc"
209
- if not rc_path .exists ():
210
- CONSOLE .log (f":person_shrugging: { rc_path .name } not found, skipping." )
211
- else :
212
- CONSOLE .log (f":mag: Found { rc_path .name } !" )
213
- shells_found .append (shell )
345
+ shells_supported : List [ShellType ] = list (typing_get_args (ShellType ))
346
+ if conda_path is not None :
347
+ # Running in conda; we have to support all shells.
348
+ shells_found = shells_supported
349
+ CONSOLE .log (f":mag: Detected conda environment { conda_path } !" )
350
+ else :
351
+ shells_found : List [ShellType ] = []
352
+ for shell in shells_supported :
353
+ rc_path = pathlib .Path (os .environ ["HOME" ]) / f".{ shell } rc"
354
+ if not rc_path .exists ():
355
+ CONSOLE .log (f":person_shrugging: { rc_path .name } not found, skipping." )
356
+ else :
357
+ CONSOLE .log (f":mag: Found { rc_path .name } !" )
358
+ shells_found .append (shell )
214
359
215
360
# Get scripts/ directory.
216
361
completions_dir = pathlib .Path (__file__ ).absolute ().parent
@@ -230,48 +375,18 @@ def main(mode: ConfigureMode = "install") -> None:
230
375
else :
231
376
CONSOLE .log (f":heavy_check_mark: No existing completions at: { target_dir } ." )
232
377
elif mode == "install" :
233
- # Set to True to install completions for scripts as well.
234
- include_scripts = False
235
-
236
- # Find tyro CLIs.
237
- script_paths = list (filter (_check_tyro_cli , scripts_dir .glob ("**/*.py" ))) if include_scripts else []
238
- script_names = tuple (p .name for p in script_paths )
239
- assert len (set (script_names )) == len (script_names )
240
-
241
- # Get existing completion files.
242
- existing_completions = set ()
243
- for shell in shells_supported :
244
- target_dir = completions_dir / shell
245
- if target_dir .exists ():
246
- existing_completions |= set (target_dir .glob ("*" ))
247
-
248
- # Run generation jobs.
249
- concurrent_executor = concurrent .futures .ThreadPoolExecutor ()
250
- with CONSOLE .status ("[bold]:writing_hand: Generating completions..." , spinner = "bouncingBall" ):
251
- completion_paths = list (
252
- concurrent_executor .map (
253
- lambda path_or_entrypoint_and_shell : _generate_completion (
254
- path_or_entrypoint_and_shell [0 ], path_or_entrypoint_and_shell [1 ], completions_dir
255
- ),
256
- itertools .product (script_paths + ENTRYPOINTS , shells_found ),
257
- )
258
- )
259
-
260
- # Delete obsolete completion files.
261
- for unexpected_path in set (p .absolute () for p in existing_completions ) - set (
262
- p .absolute () for p in completion_paths
263
- ):
264
- if unexpected_path .is_dir ():
265
- shutil .rmtree (unexpected_path )
266
- elif unexpected_path .exists ():
267
- unexpected_path .unlink ()
268
- CONSOLE .log (f":broom: Deleted { unexpected_path } ." )
378
+ _generate_completions_files (completions_dir , scripts_dir , shells_supported , shells_found )
269
379
else :
270
380
assert_never (mode )
271
381
272
- # Install or uninstall from bashrc/zshrc.
273
- for shell in shells_found :
274
- _update_rc (completions_dir , mode , shell )
382
+ if conda_path is not None :
383
+ # In conda environment we add the completitions activation scripts.
384
+ commands = _get_all_entry_points ()
385
+ _update_conda_scripts (commands , completions_dir , mode )
386
+ else :
387
+ # Install or uninstall from bashrc/zshrc.
388
+ for shell in shells_found :
389
+ _update_rc (completions_dir , mode , shell )
275
390
276
391
CONSOLE .print ("[bold]All done![/bold]" )
277
392
0 commit comments