Skip to content

Commit

Permalink
Flesh out extend_schema method, update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
iameskild committed Jun 26, 2023
1 parent 11ac6b9 commit 309672b
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 53 deletions.
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ addopts =
--tb=native
# turn warnings into errors
-Werror
# ignore deprecation warnings (TODO: filter further)
-W ignore::DeprecationWarning
markers =
conda: conda required to run this test (deselect with '-m \"not conda\"')
testpaths =
Expand Down
65 changes: 45 additions & 20 deletions src/_nebari/subcommands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import questionary
import rich
import typer
from pydantic import BaseModel

from _nebari.initialize import render_config
from nebari import schema
Expand Down Expand Up @@ -45,19 +46,10 @@ def enum_to_list(enum_cls):
return [e.value for e in enum_cls]


def handle_init(inputs: schema.InitInputs):
def handle_init(inputs: schema.InitInputs, config_schema: BaseModel):
"""
Take the inputs from the `nebari init` command, render the config and write it to a local yaml file.
"""
# if NEBARI_IMAGE_TAG:
# print(
# f"Modifying the image tags for the `default_images`, setting tags to: {NEBARI_IMAGE_TAG}"
# )

# if NEBARI_DASK_VERSION:
# print(
# f"Modifying the version of the `nebari_dask` package, setting version to: {NEBARI_DASK_VERSION}"
# )

# this will force the `set_kubernetes_version` to grab the latest version
if inputs.kubernetes_version == "latest":
Expand All @@ -80,7 +72,9 @@ def handle_init(inputs: schema.InitInputs):
)

try:
schema.write_configuration(pathlib.Path("nebari-config.yaml"), config, mode="x")
schema.write_configuration(
inputs.output, config, mode="x", config_schema=config_schema
)
except FileExistsError:
raise ValueError(
"A nebari-config.yaml file already exists. Please move or delete it and try again."
Expand All @@ -104,6 +98,24 @@ def check_ssl_cert_email(ctx: typer.Context, ssl_cert_email: str):
return ssl_cert_email


def check_repository_creds(ctx: typer.Context, git_provider: str):
"""Validate the necessary Git provider (GitHub) credentials are set."""

if (
git_provider == schema.GitRepoEnum.github.value.lower()
and not os.environ.get("GITHUB_USERNAME")
or not os.environ.get("GITHUB_TOKEN")
):
os.environ["GITHUB_USERNAME"] = typer.prompt(
"Paste your GITHUB_USERNAME",
hide_input=True,
)
os.environ["GITHUB_TOKEN"] = typer.prompt(
"Paste your GITHUB_TOKEN",
hide_input=True,
)


def check_auth_provider_creds(ctx: typer.Context, auth_provider: str):
"""Validate the the necessary auth provider credentials have been set as environment variables."""
if ctx.params.get("disable_prompt"):
Expand Down Expand Up @@ -333,6 +345,12 @@ def init(
False,
is_eager=True,
),
output: str = typer.Option(
pathlib.Path("nebari-config.yaml"),
"--output",
"-o",
help="Output file path for the rendered config file.",
),
):
"""
Create and initialize your [purple]nebari-config.yaml[/purple] file.
Expand Down Expand Up @@ -363,8 +381,13 @@ def init(
inputs.kubernetes_version = kubernetes_version
inputs.ssl_cert_email = ssl_cert_email
inputs.disable_prompt = disable_prompt
inputs.output = output

handle_init(inputs)
from nebari.plugins import nebari_plugin_manager

handle_init(inputs, config_schema=nebari_plugin_manager.config_schema)

nebari_plugin_manager.load_config(output)


def guided_init_wizard(ctx: typer.Context, guided_init: str):
Expand Down Expand Up @@ -522,7 +545,7 @@ def guided_init_wizard(ctx: typer.Context, guided_init: str):

git_provider = questionary.select(
"Which git provider would you like to use?",
choices=enum_to_list(GitRepoEnum),
choices=enum_to_list(schema.GitRepoEnum),
qmark=qmark,
).unsafe_ask()

Expand All @@ -540,7 +563,7 @@ def guided_init_wizard(ctx: typer.Context, guided_init: str):
git_provider=git_provider, org_name=org_name, repo_name=repo_name
)

if git_provider == GitRepoEnum.github.value.lower():
if git_provider == schema.GitRepoEnum.github.value.lower():
inputs.repository_auto_provision = questionary.confirm(
f"Would you like nebari to create a remote repository on {git_provider}?",
default=False,
Expand All @@ -551,10 +574,10 @@ def guided_init_wizard(ctx: typer.Context, guided_init: str):
if not disable_checks and inputs.repository_auto_provision:
check_repository_creds(ctx, git_provider)

if git_provider == GitRepoEnum.github.value.lower():
inputs.ci_provider = CiEnum.github_actions.value.lower()
elif git_provider == GitRepoEnum.gitlab.value.lower():
inputs.ci_provider = CiEnum.gitlab_ci.value.lower()
if git_provider == schema.GitRepoEnum.github.value.lower():
inputs.ci_provider = schema.CiEnum.github_actions.value.lower()
elif git_provider == schema.GitRepoEnum.gitlab.value.lower():
inputs.ci_provider = schema.CiEnum.gitlab_ci.value.lower()

# SSL CERTIFICATE
if inputs.domain_name:
Expand Down Expand Up @@ -598,7 +621,7 @@ def guided_init_wizard(ctx: typer.Context, guided_init: str):
# TERRAFORM STATE
inputs.terraform_state = questionary.select(
"Where should the Terraform State be provisioned?",
choices=enum_to_list(TerraformStateEnum),
choices=enum_to_list(schema.TerraformStateEnum),
qmark=qmark,
).unsafe_ask()

Expand All @@ -615,7 +638,9 @@ def guided_init_wizard(ctx: typer.Context, guided_init: str):
qmark=qmark,
).unsafe_ask()

handle_init(inputs)
from nebari.plugins import nebari_plugin_manager

handle_init(inputs, config_schema=nebari_plugin_manager.config_schema)

rich.print(
(
Expand Down
6 changes: 5 additions & 1 deletion src/_nebari/subcommands/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,9 @@ def validate(
# comment_on_pr(config)
pass
else:
schema.read_configuration(config_filename)
from nebari.plugins import nebari_plugin_manager

schema.read_configuration(
config_filename, config_schema=nebari_plugin_manager.config_schema
)
print("[bold purple]Successfully validated configuration.[/bold purple]")
4 changes: 2 additions & 2 deletions src/nebari/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from nebari.plugins import Nebari
from nebari.plugins import nebari_plugin_manager


def main():
Nebari()
nebari_plugin_manager.create_cli()


if __name__ == "__main__":
Expand Down
6 changes: 4 additions & 2 deletions src/nebari/hookspecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import typer
from pluggy import HookimplMarker, HookspecMarker
from pydantic import BaseModel

from nebari import schema

Expand All @@ -12,8 +13,9 @@


class NebariStage:
name = None
priority = None
name: str = None
priority: int = None
stage_schema: BaseModel = None

def __init__(self, output_directory: pathlib.Path, config: schema.Main):
self.output_directory = output_directory
Expand Down
66 changes: 44 additions & 22 deletions src/nebari/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typer.core import TyperGroup

from _nebari.version import __version__
from nebari import hookspecs
from nebari import hookspecs, schema

DEFAULT_SUBCOMMAND_PLUGINS = [
# subcommands
Expand All @@ -31,6 +31,7 @@
]

DEFAULT_STAGES_PLUGINS = [
# stages
"_nebari.stages.bootstrap",
"_nebari.stages.terraform_state",
"_nebari.stages.infrastructure",
Expand All @@ -43,7 +44,7 @@
]


class Nebari:
class NebariPluginManager:
plugin_manager = pluggy.PluginManager("nebari")

ordered_stages: typing.List[hookspecs.NebariStage] = []
Expand All @@ -53,7 +54,8 @@ class Nebari:
cli: typer.Typer = None

schema_name: str = "NebariConfig"
config_path: typing.Union[str, Path, None] = None
config_schema: typing.Union[BaseModel, None] = None
config_path: typing.Union[Path, None] = None
config: typing.Union[BaseModel, None] = None

def __init__(self) -> None:
Expand All @@ -63,10 +65,9 @@ def __init__(self) -> None:
# Only load plugins if not running tests
self.plugin_manager.load_setuptools_entrypoints("nebari")

# create and start CLI
self.load_subcommands(DEFAULT_SUBCOMMAND_PLUGINS)
self.cli = self._create_cli()
self.cli()
self.ordered_stages = self.get_available_stages()
self.config_schema = self.extend_schema()

def load_subcommands(self, subcommand: typing.List[str]):
self._load_plugins(subcommand)
Expand Down Expand Up @@ -124,32 +125,45 @@ def get_available_stages(self):

return included_stages

def create_dynamic_schema(self, plugin: BaseModel, base: BaseModel) -> BaseModel:
extra_fields = {}
for n, f in plugin.__fields__.items():
extra_fields[n] = (f.type_, f.default if f.default is not None else ...)
def load_config(self, config_path: typing.Union[str, Path]):
if isinstance(config_path, str):
config_path = Path(config_path)

if not config_path.exists():
raise FileNotFoundError(f"Config file {config_path} not found")

self.config_path = config_path
self.config = schema.read_configuration(config_path)

def _create_dynamic_schema(
self, base: BaseModel, stage: BaseModel, stage_name: str
) -> BaseModel:
stage_fields = {
n: (f.type_, f.default if f.default is not None else ...)
for n, f in stage.__fields__.items()
}
# ensure top-level key for `stage` is set to `stage_name`
stage_model = create_model(stage_name, __base__=schema.Base, **stage_fields)
extra_fields = {stage_name: (stage_model, None)}
return create_model(self.schema_name, __base__=base, **extra_fields)

def extend_schema(self, base_schema: BaseModel) -> BaseModel:
if not self.config:
return

for stages in self.ordered_stages():
def extend_schema(self, base_schema: BaseModel = schema.Main) -> BaseModel:
config_schema = base_schema
for stages in self.ordered_stages:
if stages.stage_schema:
self.config = self.create_dynamic_schema(
stages.stage_schema, base_schema
config_schema = self._create_dynamic_schema(
config_schema,
stages.stage_schema,
stages.name,
)

if self.config:
return self.config
return base_schema
return config_schema

def _version_callback(self, value: bool):
if value:
typer.echo(__version__)
raise typer.Exit()

def _create_cli(self) -> typer.Typer:
def create_cli(self) -> typer.Typer:
class OrderCommands(TyperGroup):
def list_commands(self, ctx: typer.Context):
"""Return list of commands in the order appear."""
Expand Down Expand Up @@ -208,6 +222,7 @@ def common(
self.exclude_default_stages = exclude_default_stages
self.exclude_stages = excluded_stages
self.ordered_stages = self.get_available_stages()
self.config_schema = self.extend_schema()

@cli.command()
def info(ctx: typer.Context):
Expand Down Expand Up @@ -245,4 +260,11 @@ def info(ctx: typer.Context):
rich.print(table)

self.plugin_manager.hook.nebari_subcommand(cli=cli)

self.cli = cli
self.cli()

return cli


nebari_plugin_manager = NebariPluginManager()
12 changes: 9 additions & 3 deletions src/nebari/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ class InitInputs(Base):
kubernetes_version: typing.Union[str, None] = None
ssl_cert_email: typing.Union[str, None] = None
disable_prompt: bool = False
output: pathlib.Path = pathlib.Path("nebari-config.yaml")


class CLIContext(Base):
Expand Down Expand Up @@ -1101,7 +1102,11 @@ def set_config_from_environment_variables(
return config


def read_configuration(config_filename: pathlib.Path, read_environment: bool = True):
def read_configuration(
config_filename: pathlib.Path,
read_environment: bool = True,
config_schema: pydantic.BaseModel = Main,
):
"""Read configuration from multiple sources and apply validation"""
filename = pathlib.Path(config_filename)

Expand All @@ -1115,7 +1120,7 @@ def read_configuration(config_filename: pathlib.Path, read_environment: bool = T
)

with filename.open() as f:
config = Main(**yaml.load(f.read()))
config = config_schema(**yaml.load(f.read()))

if read_environment:
config = set_config_from_environment_variables(config)
Expand All @@ -1127,13 +1132,14 @@ def write_configuration(
config_filename: pathlib.Path,
config: typing.Union[Main, typing.Dict],
mode: str = "w",
config_schema: pydantic.BaseModel = Main,
):
yaml = YAML()
yaml.preserve_quotes = True
yaml.default_flow_style = False

with config_filename.open(mode) as f:
if isinstance(config, Main):
if isinstance(config, config_schema):
yaml.dump(config.dict(), f)
else:
yaml.dump(config, f)
Loading

0 comments on commit 309672b

Please sign in to comment.