diff --git a/src/_nebari/__main__.py b/src/_nebari/__main__.py deleted file mode 100644 index b18eaf428..000000000 --- a/src/_nebari/__main__.py +++ /dev/null @@ -1,10 +0,0 @@ -from _nebari.cli import create_cli - - -def main(): - cli = create_cli() - cli() - - -if __name__ == "__main__": - main() diff --git a/src/_nebari/cli.py b/src/_nebari/cli.py deleted file mode 100644 index 0ebef8f1d..000000000 --- a/src/_nebari/cli.py +++ /dev/null @@ -1,101 +0,0 @@ -import typing - -import typer -from typer.core import TyperGroup - -from _nebari.version import __version__ -from nebari import schema -from nebari.plugins import load_plugins, pm - - -class OrderCommands(TyperGroup): - def list_commands(self, ctx: typer.Context): - """Return list of commands in the order appear.""" - return list(self.commands) - - -def version_callback(value: bool): - if value: - typer.echo(__version__) - raise typer.Exit() - - -def exclude_stages(ctx: typer.Context, stages: typing.List[str]): - ctx.ensure_object(schema.CLIContext) - ctx.obj.excluded_stages = stages - return stages - - -def exclude_default_stages(ctx: typer.Context, exclude_default_stages: bool): - ctx.ensure_object(schema.CLIContext) - ctx.obj.exclude_default_stages = exclude_default_stages - return exclude_default_stages - - -def import_plugin(plugins: typing.List[str]): - try: - load_plugins(plugins) - except ModuleNotFoundError: - typer.echo( - "ERROR: Python module {e.name} not found. Make sure that the module is in your python path {sys.path}" - ) - typer.Exit() - return plugins - - -def create_cli(): - app = typer.Typer( - cls=OrderCommands, - help="Nebari CLI 🪴", - add_completion=False, - no_args_is_help=True, - rich_markup_mode="rich", - pretty_exceptions_show_locals=False, - context_settings={"help_option_names": ["-h", "--help"]}, - ) - - @app.callback() - def common( - ctx: typer.Context, - version: bool = typer.Option( - None, - "-V", - "--version", - help="Nebari version number", - callback=version_callback, - ), - plugins: typing.List[str] = typer.Option( - [], - "--import-plugin", - help="Import nebari plugin", - ), - excluded_stages: typing.List[str] = typer.Option( - [], - "--exclude-stage", - help="Exclude nebari stage(s) by name or regex", - ), - exclude_default_stages: bool = typer.Option( - False, - "--exclude-default-stages", - help="Exclude default nebari included stages", - ), - ): - try: - load_plugins(plugins) - except ModuleNotFoundError: - typer.echo( - "ERROR: Python module {e.name} not found. Make sure that the module is in your python path {sys.path}" - ) - typer.Exit() - - from _nebari.stages.base import get_available_stages - - ctx.ensure_object(schema.CLIContext) - ctx.obj.stages = get_available_stages( - exclude_default_stages=exclude_default_stages, - exclude_stages=excluded_stages, - ) - - pm.hook.nebari_subcommand(cli=app) - - return app diff --git a/src/_nebari/stages/base.py b/src/_nebari/stages/base.py index bd4f5199b..8f603250d 100644 --- a/src/_nebari/stages/base.py +++ b/src/_nebari/stages/base.py @@ -1,9 +1,7 @@ import contextlib import inspect -import itertools import os import pathlib -import re from typing import Any, Dict, List, Tuple from _nebari.provider import terraform @@ -107,49 +105,3 @@ def destroy( if not ignore_errors: raise e status["stages/" + self.name] = False - - -def get_available_stages( - exclude_default_stages: bool = False, exclude_stages: List[str] = [] -): - from nebari.plugins import load_plugins, pm - - DEFAULT_STAGES = [ - "_nebari.stages.bootstrap", - "_nebari.stages.terraform_state", - "_nebari.stages.infrastructure", - "_nebari.stages.kubernetes_initialize", - "_nebari.stages.kubernetes_ingress", - "_nebari.stages.kubernetes_keycloak", - "_nebari.stages.kubernetes_keycloak_configuration", - "_nebari.stages.kubernetes_services", - "_nebari.stages.nebari_tf_extensions", - ] - - if not exclude_default_stages: - load_plugins(DEFAULT_STAGES) - - stages = itertools.chain.from_iterable(pm.hook.nebari_stage()) - - # order stages by priority - sorted_stages = sorted(stages, key=lambda s: s.priority) - - # filter out duplicate stages with same name (keep highest priority) - visited_stage_names = set() - filtered_stages = [] - for stage in reversed(sorted_stages): - if stage.name in visited_stage_names: - continue - filtered_stages.insert(0, stage) - visited_stage_names.add(stage.name) - - # filter out stages which match excluded stages - included_stages = [] - for stage in filtered_stages: - for exclude_stage in exclude_stages: - if re.fullmatch(exclude_stage, stage.name) is not None: - break - else: - included_stages.append(stage) - - return included_stages diff --git a/src/_nebari/subcommands/info.py b/src/_nebari/subcommands/info.py deleted file mode 100644 index f49f3da90..000000000 --- a/src/_nebari/subcommands/info.py +++ /dev/null @@ -1,42 +0,0 @@ -import collections - -import rich -import typer -from rich.table import Table - -from _nebari.version import __version__ -from nebari.hookspecs import hookimpl -from nebari.plugins import pm - - -@hookimpl -def nebari_subcommand(cli: typer.Typer): - @cli.command() - def info(ctx: typer.Context): - rich.print(f"Nebari version: {__version__}") - - hooks = collections.defaultdict(list) - for plugin in pm.get_plugins(): - for hook in pm.get_hookcallers(plugin): - hooks[hook.name].append(plugin.__name__) - - table = Table(title="Hooks") - table.add_column("hook", justify="left", no_wrap=True) - table.add_column("module", justify="left", no_wrap=True) - - for hook_name, modules in hooks.items(): - for module in modules: - table.add_row(hook_name, module) - - rich.print(table) - - table = Table(title="Runtime Stage Ordering") - table.add_column("name") - table.add_column("priority") - table.add_column("module") - for stage in ctx.obj.stages: - table.add_row( - stage.name, str(stage.priority), f"{stage.__module__}.{stage.__name__}" - ) - - rich.print(table) diff --git a/src/nebari/__main__.py b/src/nebari/__main__.py index b18eaf428..33f35f36a 100644 --- a/src/nebari/__main__.py +++ b/src/nebari/__main__.py @@ -1,9 +1,8 @@ -from _nebari.cli import create_cli +from nebari.plugins import Nebari def main(): - cli = create_cli() - cli() + Nebari() if __name__ == "__main__": diff --git a/src/nebari/plugins.py b/src/nebari/plugins.py index 37fdc7bae..0a62e3917 100644 --- a/src/nebari/plugins.py +++ b/src/nebari/plugins.py @@ -1,10 +1,20 @@ +import collections import importlib +import itertools import os +import re import sys import typing +from pathlib import Path import pluggy +import rich +import typer +from pydantic import BaseModel, create_model +from rich.table import Table +from typer.core import TyperGroup +from _nebari.version import __version__ from nebari import hookspecs DEFAULT_SUBCOMMAND_PLUGINS = [ @@ -18,38 +28,221 @@ "_nebari.subcommands.support", "_nebari.subcommands.upgrade", "_nebari.subcommands.validate", - "_nebari.subcommands.info", ] -pm = pluggy.PluginManager("nebari") -pm.add_hookspecs(hookspecs) +DEFAULT_STAGES_PLUGINS = [ + "_nebari.stages.bootstrap", + "_nebari.stages.terraform_state", + "_nebari.stages.infrastructure", + "_nebari.stages.kubernetes_initialize", + "_nebari.stages.kubernetes_ingress", + "_nebari.stages.kubernetes_keycloak", + "_nebari.stages.kubernetes_keycloak_configuration", + "_nebari.stages.kubernetes_services", + "_nebari.stages.nebari_tf_extensions", +] + + +class Nebari: + plugin_manager = pluggy.PluginManager("nebari") + + ordered_stages: typing.List[hookspecs.NebariStage] = [] + exclude_default_stages: bool = False + exclude_stages: typing.List[str] = [] + + cli: typer.Typer = None + + schema_name: str = "NebariConfig" + config_path: typing.Union[str, Path, None] = None + config: typing.Union[BaseModel, None] = None + + def __init__(self) -> None: + self.plugin_manager.add_hookspecs(hookspecs) + + if not hasattr(sys, "_called_from_test"): + # Only load plugins if not running tests + self.plugin_manager.load_setuptools_entrypoints("nebari") + + # create and start CLI + self.load_subcommands(DEFAULT_SUBCOMMAND_PLUGINS) + self.cli = self._create_cli() + self.cli() + + def load_subcommands(self, subcommand: typing.List[str]): + self._load_plugins(subcommand) + + def load_stages(self, stages: typing.List[str]): + self._load_plugins(stages) + + def _load_plugins(self, plugins: typing.List[str]): + def _import_module_from_filename(plugin: str): + module_name = f"_nebari.stages._files.{plugin.replace(os.sep, '.')}" + spec = importlib.util.spec_from_file_location(module_name, plugin) + mod = importlib.util.module_from_spec(spec) + sys.modules[module_name] = mod + spec.loader.exec_module(mod) + return mod + + for plugin in plugins: + if plugin.endswith(".py"): + mod = _import_module_from_filename(plugin) + else: + mod = importlib.import_module(plugin) + + try: + self.plugin_manager.register(mod, plugin) + except ValueError: + # Pluin already registered + pass + + def get_available_stages(self): + if not self.exclude_default_stages: + self.load_stages(DEFAULT_STAGES_PLUGINS) + + stages = itertools.chain.from_iterable(self.plugin_manager.hook.nebari_stage()) + + # order stages by priority + sorted_stages = sorted(stages, key=lambda s: s.priority) + + # filter out duplicate stages with same name (keep highest priority) + visited_stage_names = set() + filtered_stages = [] + for stage in reversed(sorted_stages): + if stage.name in visited_stage_names: + continue + filtered_stages.insert(0, stage) + visited_stage_names.add(stage.name) + + # filter out stages which match excluded stages + included_stages = [] + for stage in filtered_stages: + for exclude_stage in self.exclude_stages: + if re.fullmatch(exclude_stage, stage.name) is not None: + break + else: + included_stages.append(stage) + + return included_stages + + def create_dynamic_schema(self, plugin: BaseModel, base: BaseModel) -> BaseModel: + extra_fields = {} + for n, f in plugin.__fields__.items(): + extra_fields[n] = (f.type_, f.default if f.default is not None else ...) + return create_model(self.schema_name, __base__=base, **extra_fields) + + def extend_schema(self, base_schema: BaseModel) -> BaseModel: + if not self.config: + return + + for stages in self.ordered_stages(): + if stages.stage_schema: + self.config = self.create_dynamic_schema( + stages.stage_schema, base_schema + ) + + if self.config: + return self.config + return base_schema + + def _version_callback(self, value: bool): + if value: + typer.echo(__version__) + raise typer.Exit() + + def _create_cli(self) -> typer.Typer: + class OrderCommands(TyperGroup): + def list_commands(self, ctx: typer.Context): + """Return list of commands in the order appear.""" + return list(self.commands) + + cli = typer.Typer( + cls=OrderCommands, + help="Nebari CLI 🪴", + add_completion=False, + no_args_is_help=True, + rich_markup_mode="rich", + pretty_exceptions_show_locals=False, + context_settings={"help_option_names": ["-h", "--help"]}, + ) + + @cli.callback() + def common( + ctx: typer.Context, + version: bool = typer.Option( + None, + "-V", + "--version", + help="Nebari version number", + callback=self._version_callback, + ), + extra_stages: typing.List[str] = typer.Option( + [], + "--import-plugin", + help="Import nebari plugin", + ), + extra_subcommands: typing.List[str] = typer.Option( + [], + "--import-subcommand", + help="Import nebari subcommand", + ), + excluded_stages: typing.List[str] = typer.Option( + [], + "--exclude-stage", + help="Exclude nebari stage(s) by name or regex", + ), + exclude_default_stages: bool = typer.Option( + False, + "--exclude-default-stages", + help="Exclude default nebari included stages", + ), + ): + try: + self.load_stages(extra_stages) + self.load_subcommands(extra_subcommands) + except ModuleNotFoundError: + typer.echo( + "ERROR: Python module {e.name} not found. Make sure that the module is in your python path {sys.path}" + ) + typer.Exit() + + self.exclude_default_stages = exclude_default_stages + self.exclude_stages = excluded_stages + self.ordered_stages = self.get_available_stages() + + @cli.command() + def info(ctx: typer.Context): + """ + Display the version and available hooks for Nebari. + """ + rich.print(f"Nebari version: {__version__}") -if not hasattr(sys, "_called_from_test"): - # Only load plugins if not running tests - pm.load_setuptools_entrypoints("nebari") + hooks = collections.defaultdict(list) + for plugin in self.plugin_manager.get_plugins(): + for hook in self.plugin_manager.get_hookcallers(plugin): + hooks[hook.name].append(plugin.__name__) + table = Table(title="Hooks") + table.add_column("hook", justify="left", no_wrap=True) + table.add_column("module", justify="left", no_wrap=True) -# Load default plugins -def load_plugins(plugins: typing.List[str]): - def _import_module_from_filename(filename: str): - module_name = f"_nebari.stages._files.{plugin.replace(os.sep, '.')}" - spec = importlib.util.spec_from_file_location(module_name, plugin) - mod = importlib.util.module_from_spec(spec) - sys.modules[module_name] = mod - spec.loader.exec_module(mod) - return mod + for hook_name, modules in hooks.items(): + for module in modules: + table.add_row(hook_name, module) - for plugin in plugins: - if plugin.endswith(".py"): - mod = _import_module_from_filename(plugin) - else: - mod = importlib.import_module(plugin) + rich.print(table) - try: - pm.register(mod, plugin) - except ValueError: - # Pluin already registered - pass + table = Table(title="Runtime Stage Ordering") + table.add_column("name") + table.add_column("priority") + table.add_column("module") + for stage in self.ordered_stages: + table.add_row( + stage.name, + str(stage.priority), + f"{stage.__module__}.{stage.__name__}", + ) + rich.print(table) -load_plugins(DEFAULT_SUBCOMMAND_PLUGINS) + self.plugin_manager.hook.nebari_subcommand(cli=cli) + return cli