Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 93 additions & 116 deletions scripts/bootstrap_resource_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,16 @@
import subprocess
import tempfile
from dataclasses import dataclass
from distutils.util import strtobool
from typing import Optional

import yaml

from gigl.common import GcsUri, HttpUri, LocalUri, UriFactory
from gigl.common import GcsUri, UriFactory
from gigl.src.common.utils.file_loader import FileLoader

GIGL_ROOT_DIR = pathlib.Path(__file__).resolve().parent.parent
LOCAL_DEV_TEMPLATE_RES_CONF = LocalUri(
GIGL_ROOT_DIR / "deployment" / "configs" / "unittest_resource_config.yaml"
)
FALLBACK_TEMPLATE_RES_CONF = HttpUri(
uri="https://raw.githubusercontent.com/Snapchat/GiGL/refs/heads/main/deployment/configs/unittest_resource_config.yaml"
)
CURR_DATETIME = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
CURR_USERNAME = getpass.getuser()


@dataclass
Expand All @@ -31,6 +26,35 @@ class Param:
required: bool = True


@dataclass
class TemplateConfig:
template_path: pathlib.Path
env_var_to_update: str

def __post_init__(self):
if not self.template_path.exists():
raise FileNotFoundError(f"Template file not found: {self.template_path}")


# Template configurations to process
TEMPLATE_CONFIGS = {
"default": TemplateConfig(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we name this tabularized?

template_path=GIGL_ROOT_DIR
/ "deployment"
/ "configs"
/ "e2e_cicd_resource_config.yaml",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just pointing out some inconsistency in naming re: cicd vs. glt. not sure if this is appropriate / intended

env_var_to_update="GIGL_TEST_DEFAULT_RESOURCE_CONFIG",
),
"in_memory": TemplateConfig(
template_path=GIGL_ROOT_DIR
/ "deployment"
/ "configs"
/ "e2e_glt_resource_config.yaml",
env_var_to_update="GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG",
),
}


class SupportedParams:
def __init__(self):
try:
Expand Down Expand Up @@ -75,22 +99,18 @@ def __init__(self):
default=None,
description="`perm_assets_bucket` - GCS Bucket for storing permanent assets i.e. `gs://YOUR_BUCKET_NAME`",
),
"template_resource_config_uri": Param(
default=None,
description="URI to the template resource config file to use for bootstrapping. If provided, will be used as the 'Base' for the resource config, with the appropriate fields overwritten by the values provided in this script.",
),
"output_resource_config_path": Param(
default=None,
required=False,
description="Path to the output resource config file. If not provided, one will be generated in the `perm_assets_bucket`.",
),
"force_shell_config_update": Param(
default="False",
description="If set to True, will not ask to update the shell configuration file. If False, will prompt the user to update the shell configuration file.",
),
}


def compute_resource_config_destination_path(
name: str,
perm_assets_bucket: str,
) -> GcsUri:
return GcsUri(
f"{perm_assets_bucket}/{CURR_USERNAME}/{CURR_DATETIME}/{name}_resource_config.yaml"
)


def infer_shell_file() -> str:
"""Infers the user's default shell configuration file."""
shell = os.environ.get("SHELL", "")
Expand All @@ -114,22 +134,25 @@ def infer_shell_file() -> str:

def update_shell_config(
shell_config_path: str,
gigl_test_default_resource_config: str,
gigl_project: str,
gigl_docker_artifact_registry_path: str,
shell_env_vars: dict[str, str],
):
"""Updates the shell configuration file with the environment variables in an idempotent way."""
shell_config_path = os.path.expanduser(shell_config_path)
start_marker = "# ====== GiGL ENV Config - Begin ====="
end_marker = "# ====== GiGL ENV Config - End ====="
export_lines = [
start_marker + "\n",
"# This section is auto-generated by GiGL/scripts/bootstrap_resource_config.py.\n",
f'export GIGL_TEST_DEFAULT_RESOURCE_CONFIG="{gigl_test_default_resource_config}"\n',
f'export GIGL_PROJECT="{gigl_project}"\n',
f'export GIGL_DOCKER_ARTIFACT_REGISTRY="{gigl_docker_artifact_registry_path}"\n',
end_marker + "\n",
f'export {key}="{value}"\n' for key, value in shell_env_vars.items()
]
export_lines = (
[
start_marker + "\n",
"# This section is auto-generated by GiGL/scripts/bootstrap_resource_config.py.\n",
]
+ export_lines
+ [
end_marker + "\n",
]
)

Comment on lines -127 to 156
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this whole bit is a bit confusing with the different variables and hard-coded strings and lists.

Can we just use one f-string?

# Read the existing shell config file
if not os.path.exists(shell_config_path):
Expand Down Expand Up @@ -243,25 +266,6 @@ def assert_gcs_bucket_exists(bucket_name: str):
raise ValueError(
f"Missing required value for {key}. Please provide a value."
)
resource_config_path: str
file_loader = FileLoader()
if args.template_resource_config_uri:
print(
f"Using provided template resource config: {args.template_resource_config_uri}"
)
resource_config_path = args.template_resource_config_uri
elif file_loader.does_uri_exist(uri=LOCAL_DEV_TEMPLATE_RES_CONF):
print(
f"Using local development template resource config: {LOCAL_DEV_TEMPLATE_RES_CONF}"
)
resource_config_path = LOCAL_DEV_TEMPLATE_RES_CONF.uri
else:
print(f"Using fallback template resource config: {FALLBACK_TEMPLATE_RES_CONF}")
tmp_file = file_loader.load_to_temp_file(
file_uri_src=FALLBACK_TEMPLATE_RES_CONF
)
print(f"Downloaded fallback template resource config to {tmp_file.name}")
resource_config_path = tmp_file.name

# Validate existence of resources
assert_gcp_project_exists(values["project"])
Expand All @@ -279,29 +283,6 @@ def assert_gcs_bucket_exists(bucket_name: str):
assert_gcs_bucket_exists(bucket_name=values["temp_assets_bucket"])
assert_gcs_bucket_exists(bucket_name=values["perm_assets_bucket"])

curr_datetime = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
curr_username = getpass.getuser()
default_resource_config_dest_path = f"{values['perm_assets_bucket']}/{curr_username}/{curr_datetime}/gigl_test_default_resource_config.yaml"

if args.output_resource_config_path:
destination_file_path = args.output_resource_config_path
else:
destination_file_path = (
input(
f"Output path for resource config (default: {default_resource_config_dest_path}); Must be a GCS path starting with 'gs://': "
).strip()
or default_resource_config_dest_path
)

file_uri_dest = UriFactory.create_uri(uri=destination_file_path)
if not isinstance(file_uri_dest, GcsUri):
raise ValueError(
f"Destination file path must be a GCS URI starting with 'gs://'. Please provide a valid GCS path. Recived URI: {file_uri_dest.uri}. We do this prevent leaking sensitive information in the local filesystem."
)

print("=======================================================")
print(f"Will now create the resource config file @ {destination_file_path}.")
print("Using the following values:")
update_fields_dict = {
"project": values["project"],
"region": values["region"],
Expand All @@ -312,55 +293,51 @@ def assert_gcs_bucket_exists(bucket_name: str):
"temp_assets_bq_dataset_name": values["temp_assets_bq_dataset_name"],
"embedding_bq_dataset_name": values["embedding_bq_dataset_name"],
}
print("=======================================================")
print("Will now create the resource config files:")
for key, value in update_fields_dict.items():
print(f" {key}: {value}")

with open(resource_config_path, "r") as file:
config = yaml.safe_load(file)

# Update the YAML content
common_compute_config: dict = config.get("shared_resource_config").get(
"common_compute_config"
)
common_compute_config.update(update_fields_dict)

tmp_file = tempfile.NamedTemporaryFile(delete=False)
with open(tmp_file.name, "w") as file:
yaml.safe_dump(config, file)

file_loader = FileLoader(project=values["project"])
file_uri_src = UriFactory.create_uri(uri=tmp_file.name)
file_loader.load_file(file_uri_src=file_uri_src, file_uri_dst=file_uri_dest)

print(f"Updated YAML file saved at '{destination_file_path}'")

# Update the user's shell configuration
if args.force_shell_config_update and strtobool(args.force_shell_config_update):
should_update_shell_config = "y"
print("Forcing shell updated due to --force_shell_config_update flag.")
else:
should_update_shell_config = (
input(
"Do you want to update your shell configuration file so you can use this new resource config for tests? [y/n] (Default: y): "
)
.strip()
.lower()
or "y"
)
if should_update_shell_config == "y":
shell_config_path: str = infer_shell_file()
update_shell_config(
shell_config_path=shell_config_path,
gigl_test_default_resource_config=destination_file_path,
gigl_project=values["project"],
gigl_docker_artifact_registry_path=values["docker_artifact_registry_path"],
shell_env_vars = {
"GIGL_PROJECT": values["project"],
"GIGL_DOCKER_ARTIFACT_REGISTRY": values["docker_artifact_registry_path"],
}
for name, template_config in TEMPLATE_CONFIGS.items():
resource_config_destination_path = compute_resource_config_destination_path(
name=name, perm_assets_bucket=values["perm_assets_bucket"]
)

print(
f"Please restart your shell or run `source {shell_config_path}` to apply the changes."
with open(template_config.template_path, "r") as file:
config = yaml.safe_load(file)

# Update the YAML content
common_compute_config: dict = config.get("shared_resource_config").get(
"common_compute_config"
)
else:
print(
"Skipping shell configuration update. Please remember to set the environment variables manually "
+ "if you want `make unit_test | integration_test` commands to work correctly."
common_compute_config.update(update_fields_dict)

tmp_file = tempfile.NamedTemporaryFile(delete=False)
with open(tmp_file.name, "w") as file:
yaml.safe_dump(config, file)

file_loader = FileLoader(project=values["project"])
file_uri_src = UriFactory.create_uri(uri=tmp_file.name)
file_loader.load_file(
file_uri_src=file_uri_src, file_uri_dst=resource_config_destination_path
)
shell_env_vars[
template_config.env_var_to_update
Comment on lines +326 to +330
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw should we log that we created these configs somehwere?

] = resource_config_destination_path.uri

# Update the user's shell configuration (always)
print("Updating shell configuration...")
shell_config_path: str = infer_shell_file()
update_shell_config(
shell_config_path=shell_config_path,
shell_env_vars=shell_env_vars,
)
Comment on lines +336 to +339
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, can we keep this update shell config bit as an option? I feel like I would not want to update my shell config when I do this (e.g. when I create some new configs for oss but want to keep the internal config for personal use)


print(
f"Shell configuration updated. Please restart your shell or run `source {shell_config_path}` to apply the changes."
)