Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
flake8 . --ignore=E203,W503,E722,E731 --max-complexity=100 --max-line-length=160
- name: Lint with pyright (type checking)
run: |
echo TODO - fix pyright errors # pyright cf_remote
pyright cf_remote
- name: Lint with pyflakes
run: |
pyflakes cf_remote
Expand Down
104 changes: 66 additions & 38 deletions cf_remote/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
write_json,
whoami,
get_package_name,
user_error,
CFRExitError,
is_package_url,
print_progress_dot,
ChecksumError,
CFRChecksumError,
CFRUserError,
)
from cf_remote.spawn import VM, VMRequest, Providers, AWSCredentials, GCPCredentials
from cf_remote.spawn import spawn_vms, destroy_vms, dump_vms_info, get_cloud_driver
Expand Down Expand Up @@ -71,6 +72,7 @@ def run(hosts, command, users=None, sudo=False, raw=False):
continue
cmd = command
lines = lines.replace("\r", "")
fill = ""
for line in lines.split("\n"):
if raw:
print(line)
Expand All @@ -79,6 +81,7 @@ def run(hosts, command, users=None, sudo=False, raw=False):
fill = " " * (len(cmd) + 7)
cmd = None
else:
assert fill, "First iteration of loop should have set fill variable"
print("{}{}'{}'".format(host_colon, fill, line))
return errors

Expand Down Expand Up @@ -123,7 +126,9 @@ def _download_urls(urls):
paths.append(path)

if path in downloaded_paths and url not in downloaded_urls:
user_error("2 packages with the same name '%s' from different URLs" % name)
raise CFRExitError(
"2 packages with the same name '%s' from different URLs" % name
)

download_package(url, path)
downloaded_urls.append(url)
Expand All @@ -143,7 +148,7 @@ def _verify_package_urls(urls):
if is_package_url(package_url):
verified_urls.append(package_url)
else:
user_error("Wrong package URL: {}".format(package_url))
raise CFRExitError("Wrong package URL: {}".format(package_url))

return verified_urls

Expand Down Expand Up @@ -184,7 +189,7 @@ def install(
else:
try:
package, hub_package, client_package = _download_urls(packages)
except ChecksumError as ce:
except CFRChecksumError as ce:
log.error(ce)
return 1

Expand Down Expand Up @@ -288,14 +293,15 @@ def install(
def _iterate_over_packages(
tags=None, version=None, edition=None, download=False, output_dir=None
):
assert edition in ["enterprise", "community", None]
releases = Releases(edition)
print("Available releases: {}".format(releases))

release_versions = [rel.version for rel in releases.releases]
if version and version not in release_versions:
user_error("CFEngine version '%s' doesn't exist (yet)." % version)
raise CFRExitError("CFEngine version '%s' doesn't exist (yet)." % version)

if not version:
if tags and not version:
for tag in tags:
if tag in release_versions:
version = tag
Expand All @@ -305,10 +311,11 @@ def _iterate_over_packages(
release = releases.default
if version:
release = releases.pick_version(version)
if not release:
raise CFRExitError("Failed to find a release for version '%s'" % version)
print("Using {}:".format(release))
log.debug("Looking for a release based on host tags: {}".format(tags))
artifacts = release.find(tags)

if len(artifacts) == 0:
print("No suitable packages found")
else:
Expand All @@ -318,14 +325,14 @@ def _iterate_over_packages(
package_path = download_package(
artifact.url, checksum=artifact.checksum
)
except ChecksumError as ce:
except CFRChecksumError as ce:
log.error(ce)
return 1
if output_dir:
output_dir = os.path.abspath(os.path.expanduser(output_dir))
parent = os.path.dirname(output_dir)
if not os.path.exists(parent):
user_error(
raise CFRExitError(
"'{}' doesn't exist. Make sure this path is correct and exists.".format(
parent
)
Expand Down Expand Up @@ -373,16 +380,17 @@ def spawn(
public_ip=True,
extend_group=False,
):
creds_data = None
if os.path.exists(CLOUD_CONFIG_FPATH):
creds_data = read_json(CLOUD_CONFIG_FPATH)
else:
print("Cloud configuration not found at %s" % CLOUD_CONFIG_FPATH)
return 1
if not creds_data:
raise CFRUserError("Cloud configuration not found at %s" % CLOUD_CONFIG_FPATH)

vms_info = None
if os.path.exists(CLOUD_STATE_FPATH):
vms_info = read_json(CLOUD_STATE_FPATH)
else:
vms_info = dict()
if not vms_info:
vms_info = {}

group_key = "@%s" % group_name
group_exists = group_key in vms_info
Expand Down Expand Up @@ -523,6 +531,8 @@ def destroy(group_name=None):
return 1

vms_info = read_json(CLOUD_STATE_FPATH)
if not vms_info:
raise CFRUserError("No saved VMs found in '{}'".format(CLOUD_STATE_FPATH))

to_destroy = []
if group_name:
Expand All @@ -541,16 +551,23 @@ def destroy(group_name=None):

region = vms_info[group_name]["meta"]["region"]
provider = vms_info[group_name]["meta"]["provider"]
if provider not in ["aws", "gcp"]:
raise CFRUserError(
"Unsupported provider '{}' encountered in '{}', only aws / gcp is supported".format(
provider, CLOUD_STATE_FPATH
)
)

driver = None
if provider == "aws":
if aws_creds is None:
user_error("Missing/incomplete AWS credentials")
return 1
raise CFRExitError("Missing/incomplete AWS credentials")
driver = get_cloud_driver(Providers.AWS, aws_creds, region)
if provider == "gcp":
if gcp_creds is None:
user_error("Missing/incomplete GCP credentials")
return 1
raise CFRExitError("Missing/incomplete GCP credentials")
driver = get_cloud_driver(Providers.GCP, gcp_creds, region)
assert driver is not None

nodes = driver.list_nodes()
for name, vm_info in vms_info[group_name].items():
Expand All @@ -572,16 +589,23 @@ def destroy(group_name=None):

region = vms_info[group_name]["meta"]["region"]
provider = vms_info[group_name]["meta"]["provider"]
if provider not in ["aws", "gcp"]:
raise CFRUserError(
"Unsupported provider '{}' encountered in '{}', only aws / gcp is supported".format(
provider, CLOUD_STATE_FPATH
)
)

driver = None
if provider == "aws":
if aws_creds is None:
user_error("Missing/incomplete AWS credentials")
return 1
raise CFRExitError("Missing/incomplete AWS credentials")
driver = get_cloud_driver(Providers.AWS, aws_creds, region)
if provider == "gcp":
if gcp_creds is None:
user_error("Missing/incomplete GCP credentials")
return 1
raise CFRExitError("Missing/incomplete GCP credentials")
driver = get_cloud_driver(Providers.GCP, gcp_creds, region)
assert driver is not None

nodes = driver.list_nodes()
for name, vm_info in vms_info[group_name].items():
Expand Down Expand Up @@ -673,11 +697,15 @@ def save(name, hosts, role):


def _ansible_inventory():
if not os.path.exists(CLOUD_STATE_FPATH):
print("No saved cloud state info")
return 1

vms_info = read_json(CLOUD_STATE_FPATH)
vms_info = None
if os.path.exists(CLOUD_STATE_FPATH):
vms_info = read_json(CLOUD_STATE_FPATH)

if not vms_info:
raise CFRUserError(
"No saved cloud state info in '{}'".format(CLOUD_STATE_FPATH)
)
all_lines = []
hub_lines = []
client_lines = []
Expand Down Expand Up @@ -851,7 +879,7 @@ def deploy(hubs, masterfiles):
print("Found saved/spawned hubs: " + ", ".join(hubs))

if not hubs:
user_error(
raise CFRExitError(
"No hub to deploy to (Specify with --hub or use spawn/save commands to add to cf-remote)"
)

Expand All @@ -866,7 +894,7 @@ def deploy(hubs, masterfiles):
urls = [masterfiles]
try:
paths = _download_urls(urls)
except ChecksumError as ce:
except CFRChecksumError as ce:
log.error(ce)
return 1
assert len(paths) == 1
Expand All @@ -876,7 +904,7 @@ def deploy(hubs, masterfiles):
if not masterfiles:
masterfiles = "."
if not (os.path.isfile("promises.cf") or os.path.isfile("promises.cf.in")):
user_error("No cfbs or masterfiles policy set found")
raise CFRExitError("No cfbs or masterfiles policy set found")

masterfiles = os.path.abspath(os.path.expanduser(masterfiles))
print("Found masterfiles policy set: '{}'".format(masterfiles))
Expand Down Expand Up @@ -937,23 +965,23 @@ def deploy(hubs, masterfiles):


def agent(hosts, bootstrap=None):

if len(bootstrap) > 1:
user_error(
if bootstrap and len(bootstrap) > 1:
raise CFRExitError(
"Cannot boostrap {} to {}. Cannot bootstrap to more than one host.".format(
hosts, bootstrap
)
)

hub_host = bootstrap[0]

for host in hosts:
data = get_info(host)

if not data["agent_location"]:
user_error("CFEngine not installed on {}".format(host))
raise CFRExitError("CFEngine not installed on {}".format(host))

command = "{}".format(data["agent_location"])
if bootstrap:
command += "--bootstrap {}".format(bootstrap[0])

command = "{} --bootstrap {}".format(data["agent_location"], hub_host)
output = run_command(host, command, sudo=True)
if output:
print(output)
Expand All @@ -965,7 +993,7 @@ def connect_cmd(hosts):
assert hosts and len(hosts) >= 1 # Ensured by argument parser

if len(hosts) > 1:
user_error("You can only connect to one host at a time")
raise CFRExitError("You can only connect to one host at a time")

print("Opening a SSH command shell...")
r = subprocess.run(["ssh", hosts[0]])
Expand Down
Loading