Skip to content

Commit

Permalink
Move setup back to aqueduct start (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
eunice-chan authored Jul 7, 2022
1 parent 328c0c3 commit 54e4420
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 240 deletions.
205 changes: 200 additions & 5 deletions src/python/bin/aqueduct
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

import argparse
import os
import string
import random
import subprocess
import sys
import platform
import shutil
import requests
import zipfile
import platform
import socket
import subprocess
Expand All @@ -19,6 +27,11 @@ ui_directory = os.path.join(os.environ["HOME"], ".aqueduct", "ui")
package_version = "0.0.3"
default_server_port = 8080

s3_server_prefix = (
"https://aqueduct-ai.s3.us-east-2.amazonaws.com/assets/%s/server" % package_version
)
s3_ui_prefix = "https://aqueduct-ai.s3.us-east-2.amazonaws.com/assets/%s/ui" % package_version

welcome_message = """
***************************************************
Your API Key: %s
Expand All @@ -27,6 +40,190 @@ The Web UI and the backend server are accessible at: http://%s:%d
***************************************************
"""


def update_config_yaml(file):
s = string.ascii_uppercase + string.digits
encryption_key = "".join(random.sample(s, 32))
api_key = "".join(random.sample(s, 32))

with open(file, "r") as sources:
lines = sources.readlines()
with open(file, "w") as sources:
for line in lines:
if "<BASE_PATH>" in line:
sources.write(line.replace("<BASE_PATH>", server_directory))
elif "<ENCRYPTION_KEY>" in line:
sources.write(line.replace("<ENCRYPTION_KEY>", encryption_key))
elif "<API_KEY>" in line:
sources.write(line.replace("<API_KEY>", api_key))
else:
sources.write(line)
print("Updated configurations.")


def execute_command(args, cwd=None):
with subprocess.Popen(args, stdout=sys.stdout, stderr=sys.stderr, cwd=cwd) as proc:
proc.communicate()
if proc.returncode != 0:
raise Exception("Error executing command: %s" % args)


def generate_version_file(file_path):
with open(file_path, "w") as f:
f.write(package_version)


# Returns a bool indicating whether we need to perform a version upgrade.
def require_update(file_path):
if not os.path.isfile(file_path):
return True
with open(file_path, "r") as f:
current_version = f.read()
if package_version < current_version:
raise Exception(
"Attempting to install an older version %s but found existing newer version %s"
% (package_version, current_version)
)
elif package_version == current_version:
return False
else:
return True


def update_executable_permissions():
os.chmod(os.path.join(server_directory, "bin", "server"), 0o755)
os.chmod(os.path.join(server_directory, "bin", "executor"), 0o755)
os.chmod(os.path.join(server_directory, "bin", "migrator"), 0o755)


def download_server_binaries(architecture):
with open(os.path.join(server_directory, "bin/server"), "wb") as f:
f.write(requests.get(os.path.join(s3_server_prefix, f"bin/{architecture}/server")).content)
with open(os.path.join(server_directory, "bin/executor"), "wb") as f:
f.write(
requests.get(os.path.join(s3_server_prefix, f"bin/{architecture}/executor")).content
)
with open(os.path.join(server_directory, "bin/migrator"), "wb") as f:
f.write(
requests.get(os.path.join(s3_server_prefix, f"bin/{architecture}/migrator")).content
)
with open(os.path.join(server_directory, "bin/start-function-executor.sh"), "wb") as f:
f.write(
requests.get(os.path.join(s3_server_prefix, "bin/start-function-executor.sh")).content
)
with open(os.path.join(server_directory, "bin/install_sqlserver_ubuntu.sh"), "wb") as f:
f.write(
requests.get(os.path.join(s3_server_prefix, "bin/install_sqlserver_ubuntu.sh")).content
)
print("Downloaded server binaries.")


def setup_server_binaries():
print("Downloading server binaries.")
server_bin_path = os.path.join(server_directory, "bin")
shutil.rmtree(server_bin_path, ignore_errors=True)
os.mkdir(server_bin_path)

system = platform.system()
arch = platform.machine()
if system == "Linux" and arch == "x86_64":
print("Operating system is Linux with architecture amd64.")
download_server_binaries("linux_amd64")
elif system == "Darwin" and arch == "x86_64":
print("Operating system is Mac with architecture amd64.")
download_server_binaries("darwin_amd64")
elif system == "Darwin" and arch == "arm64":
print("Operating system is Mac with architecture arm64.")
download_server_binaries("darwin_arm64")
else:
raise Exception(
"Unsupported operating system and architecture combination: %s, %s" % (system, arch)
)


def update_ui_version():
print("Updating UI version to %s" % package_version)
try:
shutil.rmtree(ui_directory, ignore_errors=True)
os.mkdir(ui_directory)
generate_version_file(os.path.join(ui_directory, "__version__"))
ui_zip_path = os.path.join(ui_directory, "ui.zip")
with open(ui_zip_path, "wb") as f:
f.write(requests.get(os.path.join(s3_ui_prefix, "ui.zip")).content)
with zipfile.ZipFile(ui_zip_path, "r") as zip:
zip.extractall(ui_directory)
os.remove(ui_zip_path)
except Exception as e:
print(e)
shutil.rmtree(ui_directory, ignore_errors=True)
exit(1)


def update_server_version():
print("Updating server version to %s" % package_version)

version_file = os.path.join(server_directory, "__version__")
if os.path.isfile(version_file):
os.remove(version_file)
generate_version_file(version_file)

setup_server_binaries()
update_executable_permissions()

execute_command(
[os.path.join(server_directory, "bin", "migrator"), "--type", "sqlite", "goto", "9"]
)


def update():
if not os.path.isdir(base_directory):
os.makedirs(base_directory)

if not os.path.isdir(ui_directory) or require_update(os.path.join(ui_directory, "__version__")):
update_ui_version()

if not os.path.isdir(server_directory):
try:
directories = [
server_directory,
os.path.join(server_directory, "db"),
os.path.join(server_directory, "storage"),
os.path.join(server_directory, "storage", "operators"),
os.path.join(server_directory, "vault"),
os.path.join(server_directory, "bin"),
os.path.join(server_directory, "config"),
os.path.join(server_directory, "logs"),
]

for directory in directories:
os.mkdir(directory)

update_server_version()

with open(os.path.join(server_directory, "config/config.yml"), "wb") as f:
f.write(requests.get(os.path.join(s3_server_prefix, "config/config.yml")).content)

update_config_yaml(os.path.join(server_directory, "config", "config.yml"))

with open(os.path.join(server_directory, "db/demo.db"), "wb") as f:
f.write(requests.get(os.path.join(s3_server_prefix, "db/demo.db")).content)

print("Finished initializing Aqueduct base directory.")
except Exception as e:
print(e)
shutil.rmtree(server_directory, ignore_errors=True)
exit(1)

version_file = os.path.join(server_directory, "__version__")
if require_update(version_file):
try:
update_server_version()
except Exception as e:
print(e)
if os.path.isfile(version_file):
os.remove(version_file)
exit(1)

def execute_command(args, cwd=None):
with subprocess.Popen(args, stdout=sys.stdout, stderr=sys.stderr, cwd=cwd) as proc:
proc.communicate()
Expand All @@ -48,7 +245,9 @@ def is_port_in_use(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0

def start(expose, port):
def start(expose, port):
update()

if port is None:
server_port = default_server_port
while is_port_in_use(server_port):
Expand Down Expand Up @@ -204,10 +403,6 @@ if __name__ == "__main__":

args = parser.parse_args()

if not os.path.isdir(base_directory):
print("Please install or update aqueduct-ml with `pip install aqueduct-ml`.")
sys.exit(1)

if args.command == "start":
try:
popen_handle, server_port = start(args.expose, args.port)
Expand Down
Loading

0 comments on commit 54e4420

Please sign in to comment.