Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 105 additions & 5 deletions vast.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import textwrap
from pathlib import Path
import warnings
import platform

ARGS = None
TABCOMPLETE = False
Expand Down Expand Up @@ -1464,19 +1465,118 @@ def create__env_var(args):
else:
print(f"Failed to create environment variable: {result.get('msg', 'Unknown error')}")

def get_ssh_key_paths():
if platform.system() == "Windows":
base = os.environ["USERPROFILE"]
else:
base = os.path.expanduser("~")
key_path = os.path.join(base, ".ssh", "id_rsa")
pub_key_path = key_path + ".pub"
return key_path, pub_key_path


def ensure_ssh_key_exists(key_path, pub_key_path):
if os.path.exists(pub_key_path):
print("SSH key pair already exists, using the existing key.")
return True

print("Generating a new RSA SSH key pair...")
try:
subprocess.run(["ssh-keygen", "-t", "rsa", "-f", key_path, "-q", "-N", ""], check=True)
return True
except FileNotFoundError:
return False
except subprocess.CalledProcessError:
print("Error occurred while generating the SSH key. Please check your setup or generate manually.")
return False


def install_openssh_client_linux():
try:
print("Attempting to install OpenSSH client on Linux...")
subprocess.run(["sudo", "apt-get", "update"], check=True)
subprocess.run(["sudo", "apt-get", "install", "-y", "openssh-client"], check=True)
return True
except subprocess.CalledProcessError:
print("Failed to install openssh-client. Please install it manually.")
return False


def add_key_to_ssh_agent(key_path):
try:
subprocess.run(["ssh-add", key_path], check=True)
subprocess.run(["ssh-add", "-l"], check=True)
return True
except subprocess.CalledProcessError:
print("Unable to add SSH key to the agent. Make sure the ssh-agent is running.")
return False


def read_public_key(pub_key_path):
try:
with open(pub_key_path, "r") as f:
return f.read().strip()
except OSError:
print(f"Unable to read the public key from {pub_key_path}.")
return None


def generate_ssh_key_pair():
system = platform.system()
key_path, pub_key_path = get_ssh_key_paths()

os.makedirs(os.path.dirname(key_path), exist_ok=True)

if system == "Windows":
if not ensure_ssh_key_exists(key_path, pub_key_path):
print("'ssh-keygen' not found. Please install the OpenSSH Client and ensure it's on your PATH.")
return None
return read_public_key(pub_key_path)

elif system in ["Linux", "Darwin"]:
if not ensure_ssh_key_exists(key_path, pub_key_path):
if system == "Linux":
if install_openssh_client_linux():
if not ensure_ssh_key_exists(key_path, pub_key_path):
return None
else:
return None
else:
print("'ssh-keygen' not found. Please install it manually.")
return None

add_key_to_ssh_agent(key_path)
return read_public_key(pub_key_path)

else:
print("Unsupported platform. Only Linux, macOS, and Windows are supported.")
return None

@parser.command(
argument("ssh_key", help="add the public key of your ssh key to your account (form the .pub file)", type=str),
argument("ssh_key", help="add the public key of your ssh key to your account (or use 'auto' to generate one automatically on Linux/MacOS)", type=str, default="auto"),
usage="vastai create ssh-key ssh_key",
help="Create a new ssh-key",
epilog=deindent("""
Use this command to create a new ssh key for your account.
All ssh keys are stored in your account and can be used to connect to instances they've been added to
All ssh keys should be added in rsa format
Use this command to create a new ssh key for your account.
All ssh keys are stored in your account and can be used to connect to instances they've been added to.
All ssh keys should be added in RSA format.

Quickstart (for Linux/MacOS with 'auto'):
1. Generates an RSA key pair (if needed) using: ssh-keygen -t rsa
2. Loads the key into the SSH agent using: ssh-add; ssh-add -l
3. Reads your public key from ~/.ssh/id_rsa.pub and uses it for the account.
""")
)
def create__ssh_key(args):
if args.ssh_key == "auto":
public_key = generate_ssh_key_pair()
if public_key is None:
return
else:
public_key = args.ssh_key

url = apiurl(args, "/ssh/")
r = http_post(args, url, headers=headers, json={"ssh_key": args.ssh_key})
r = http_post(args, url, headers=headers, json={"ssh_key": public_key})
r.raise_for_status()
print("ssh-key created {}".format(r.json()))

Expand Down