Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions runpod/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

STOP_EVENT = threading.Event()


# --------------------------- runpod.toml Defaults --------------------------- #
BASE_DOCKER_IMAGE = 'runpod/base:0.4.0-cuda{cuda_version}'
GPU_TYPES = [
Expand Down
5 changes: 4 additions & 1 deletion runpod/cli/groups/ssh/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from runpod.api.ctl_commands import get_user, update_user_settings

SSH_FILES = os.path.expanduser('~/.runpod/ssh')
os.makedirs(os.path.join(SSH_FILES), exist_ok=True)


def get_ssh_key_fingerprint(public_key):
'''
Expand Down Expand Up @@ -53,6 +53,7 @@ def get_user_pub_keys():

return key_list


def generate_ssh_key_pair(filename):
"""
Generate an RSA SSH key pair and save it to disk.
Expand All @@ -61,6 +62,8 @@ def generate_ssh_key_pair(filename):
- filename (str): The base filename to save the key pair.
The public key will have '.pub' appended to it.
"""
os.makedirs(os.path.join(SSH_FILES), exist_ok=True)

# Generate private key
private_key = paramiko.RSAKey.generate(bits=2048)
private_key.write_private_key_file(os.path.join(SSH_FILES, filename))
Expand Down
4 changes: 2 additions & 2 deletions runpod/cli/utils/rp_userspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def find_ssh_key_file(pod_ip, pod_port):
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

for file in os.listdir(SSH_KEY_PATH):
for file in os.listdir(SSH_KEY_PATH):
file_path = os.path.join(SSH_KEY_PATH, file)

if not os.path.isfile(file_path) or file.endswith('.pub'):
Expand All @@ -28,7 +28,7 @@ def find_ssh_key_file(pod_ip, pod_port):
ssh.close()
print(f"Connected to pod {pod_ip}:{pod_port} using key {file}")
return file_path
except Exception as err: # pylint: disable=broad-except
except Exception as err: # pylint: disable=broad-except
print(f"An error occurred with key {file}: {err}")

print("Failed to connect using all available keys.")
Expand Down
9 changes: 6 additions & 3 deletions tests/test_cli/test_cli_groups/test_ssh_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
generate_ssh_key_pair, add_ssh_key
)


class TestSSHFunctions(unittest.TestCase):
""" Tests for the SSH functions """

Expand Down Expand Up @@ -58,9 +59,11 @@ def test_generate_ssh_key_pair(self, mock_generate, mock_path_join):
mock_generate.return_value.get_base64.return_value = "ABCDE12345"
mock_path_join.return_value = "/path/to/private_key"

with patch("builtins.open", mock_open()) as mock_file, \
patch("runpod.cli.groups.ssh.functions.os.chmod") as mock_chmod, \
patch("runpod.cli.groups.ssh.functions.add_ssh_key") as mock_add_key:
with patch("os.mkdir") as mock_mkdir, \
patch("builtins.open", mock_open()) as mock_file, \
patch("runpod.cli.groups.ssh.functions.os.chmod") as mock_chmod, \
patch("runpod.cli.groups.ssh.functions.add_ssh_key") as mock_add_key:
mock_mkdir.return_value = None
mock_file.return_value.write.return_value = None
private_key, public_key = generate_ssh_key_pair("test_key")
self.assertEqual(public_key, "ssh-rsa ABCDE12345 test_key")
Expand Down