From 7981ba9ce702cc8a8590205fcdf64cedc59df997 Mon Sep 17 00:00:00 2001 From: viniciusdc Date: Wed, 11 Sep 2024 18:43:47 -0300 Subject: [PATCH] fix missing var name & fix deployment bug & rm validation restrictions --- src/_nebari/stages/infrastructure/__init__.py | 43 +++++++++++-------- .../infrastructure/template/aws/main.tf | 1 - .../aws/modules/kubernetes/variables.tf | 6 --- .../infrastructure/template/aws/variables.tf | 11 ----- .../stages/terraform_state/__init__.py | 8 +++- 5 files changed, 32 insertions(+), 37 deletions(-) diff --git a/src/_nebari/stages/infrastructure/__init__.py b/src/_nebari/stages/infrastructure/__init__.py index 701efad3e..bd36d0619 100644 --- a/src/_nebari/stages/infrastructure/__init__.py +++ b/src/_nebari/stages/infrastructure/__init__.py @@ -132,13 +132,6 @@ class AWSNodeLaunchTemplate(schema.Base): pre_bootstrap_command: Optional[str] = None ami_id: Optional[str] = None - @field_validator("ami_id") - @classmethod - def _validate_ami_id(cls, value: Optional[str]) -> str: - if value is None: - raise ValueError("ami_id is required if pre_bootstrap_command is passed") - return value - class AWSNodeGroupInputVars(schema.Base): name: str @@ -150,9 +143,28 @@ class AWSNodeGroupInputVars(schema.Base): single_subnet: bool permissions_boundary: Optional[str] = None launch_template: Optional[AWSNodeLaunchTemplate] = None - ami_type: Optional[Literal["AL2_x86_64", "AL2_x86_64_GPU", "CUSTOM"]] = Field( - "AL2_x86_64", exclude=True - ) + ami_type: Optional[str] = None + + @field_validator("ami_type", mode="before") + @classmethod + def _infer_and_validate_ami_type(cls, value, values) -> str: + gpu_enabled = values.get("gpu", False) + + # Auto-set ami_type if not provided + if not value: + if values.get("launch_template") and values["launch_template"].ami_id: + return "CUSTOM" + if gpu_enabled: + return "AL2_x86_64_GPU" + return "AL2_x86_64" + + # Explicit validation + if value == "AL2_x86_64" and gpu_enabled: + raise ValueError( + "ami_type 'AL2_x86_64' cannot be used with GPU enabled (gpu=True)." + ) + + return value class AWSInputVars(schema.Base): @@ -162,7 +174,6 @@ class AWSInputVars(schema.Base): existing_subnet_ids: Optional[List[str]] = None region: str kubernetes_version: str - node_launch_template: Optional[AWSNodeLaunchTemplate] = None eks_endpoint_access: Optional[ Literal["private", "public", "public_and_private"] ] = "public" @@ -467,6 +478,7 @@ class AWSNodeGroup(schema.Base): gpu: bool = False single_subnet: bool = False permissions_boundary: Optional[str] = None + launch_template: Optional[AWSNodeLaunchTemplate] = None DEFAULT_AWS_NODE_GROUPS = { @@ -849,13 +861,8 @@ def input_vars(self, stage_outputs: Dict[str, Dict[str, Any]]): permissions_boundary=node_group.permissions_boundary, launch_template=( self.config.amazon_web_services.node_launch_template - if not node_group.node_launch_template - else node_group.node_launch_template - ), - ami_type=( - node_group.ami_type - if not node_group.gpu - else "AL2_x86_64_GPU" + if not node_group.launch_template + else node_group.launch_template ), ) for name, node_group in self.config.amazon_web_services.node_groups.items() diff --git a/src/_nebari/stages/infrastructure/template/aws/main.tf b/src/_nebari/stages/infrastructure/template/aws/main.tf index 2b561ba04..feffd3529 100644 --- a/src/_nebari/stages/infrastructure/template/aws/main.tf +++ b/src/_nebari/stages/infrastructure/template/aws/main.tf @@ -97,7 +97,6 @@ module "kubernetes" { node_groups = var.node_groups - node_launch_template = var.node_launch_template endpoint_public_access = var.eks_endpoint_access == "private" ? false : true endpoint_private_access = var.eks_endpoint_access == "public" ? false : true public_access_cidrs = var.eks_public_access_cidrs diff --git a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf index 7cdedcdba..4d38d10a1 100644 --- a/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf +++ b/src/_nebari/stages/infrastructure/template/aws/modules/kubernetes/variables.tf @@ -62,12 +62,6 @@ variable "node_group_instance_type" { default = "m5.large" } -variable "node_launch_template" { - description = "Custom launch template for EKS nodes" - type = map(any) - default = null -} - variable "endpoint_public_access" { type = bool default = true diff --git a/src/_nebari/stages/infrastructure/template/aws/variables.tf b/src/_nebari/stages/infrastructure/template/aws/variables.tf index 794d7eb31..a3f37b9eb 100644 --- a/src/_nebari/stages/infrastructure/template/aws/variables.tf +++ b/src/_nebari/stages/infrastructure/template/aws/variables.tf @@ -43,11 +43,6 @@ variable "node_groups" { })) } -variable "node_launch_template" { - description = "Custom launch template for EKS nodes (placeholder)" - type = map(any) -} - variable "availability_zones" { description = "AWS availability zones within AWS region" type = list(string) @@ -63,12 +58,6 @@ variable "kubeconfig_filename" { type = string } -variable "node_launch_template" { - description = "Custom launch template for EKS nodes" - type = string - default = null -} - variable "eks_endpoint_access" { description = "EKS cluster api server endpoint access setting" type = string diff --git a/src/_nebari/stages/terraform_state/__init__.py b/src/_nebari/stages/terraform_state/__init__.py index d9afff36e..97fb62652 100644 --- a/src/_nebari/stages/terraform_state/__init__.py +++ b/src/_nebari/stages/terraform_state/__init__.py @@ -2,6 +2,7 @@ import enum import functools import inspect +import json import os import pathlib import re @@ -261,11 +262,16 @@ def check_immutable_fields(self): nebari_config_diff = utils.JsonDiff( nebari_config_state.model_dump(), self.config.model_dump() ) - + # save both for testing: + with open("nebari_config_state.json", "w") as f: + f.write(json.dumps(nebari_config_state.model_dump(), indent=4)) + with open("nebari_config.json", "w") as f: + f.write(json.dumps(self.config.model_dump(), indent=4)) # check if any changed fields are immutable for keys, old, new in nebari_config_diff.modified(): bottom_level_schema = self.config if len(keys) > 1: + print(keys) bottom_level_schema = functools.reduce( lambda m, k: getattr(m, k), keys[:-1], self.config )