diff --git a/runpod/cli/__init__.py b/runpod/cli/__init__.py index 588ce35e..b58a3ff2 100644 --- a/runpod/cli/__init__.py +++ b/runpod/cli/__init__.py @@ -6,6 +6,7 @@ STOP_EVENT = threading.Event() + # --------------------------- runpod.toml Defaults --------------------------- # BASE_DOCKER_IMAGE = 'runpod/base:0.4.0-cuda{cuda_version}' GPU_TYPES = [ diff --git a/runpod/cli/groups/ssh/functions.py b/runpod/cli/groups/ssh/functions.py index f026d77a..c33c2154 100644 --- a/runpod/cli/groups/ssh/functions.py +++ b/runpod/cli/groups/ssh/functions.py @@ -10,7 +10,7 @@ from runpod.api.ctl_commands import get_user, update_user_settings SSH_FILES = os.path.expanduser('~/.runpod/ssh') -os.makedirs(os.path.join(SSH_FILES), exist_ok=True) + def get_ssh_key_fingerprint(public_key): ''' @@ -53,6 +53,7 @@ def get_user_pub_keys(): return key_list + def generate_ssh_key_pair(filename): """ Generate an RSA SSH key pair and save it to disk. @@ -61,6 +62,8 @@ def generate_ssh_key_pair(filename): - filename (str): The base filename to save the key pair. The public key will have '.pub' appended to it. """ + os.makedirs(os.path.join(SSH_FILES), exist_ok=True) + # Generate private key private_key = paramiko.RSAKey.generate(bits=2048) private_key.write_private_key_file(os.path.join(SSH_FILES, filename)) diff --git a/runpod/cli/utils/rp_userspace.py b/runpod/cli/utils/rp_userspace.py index da415abc..5a1308d6 100644 --- a/runpod/cli/utils/rp_userspace.py +++ b/runpod/cli/utils/rp_userspace.py @@ -17,7 +17,7 @@ def find_ssh_key_file(pod_ip, pod_port): ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - for file in os.listdir(SSH_KEY_PATH): + for file in os.listdir(SSH_KEY_PATH): file_path = os.path.join(SSH_KEY_PATH, file) if not os.path.isfile(file_path) or file.endswith('.pub'): @@ -28,7 +28,7 @@ def find_ssh_key_file(pod_ip, pod_port): ssh.close() print(f"Connected to pod {pod_ip}:{pod_port} using key {file}") return file_path - except Exception as err: # pylint: disable=broad-except + except Exception as err: # pylint: disable=broad-except print(f"An error occurred with key {file}: {err}") print("Failed to connect using all available keys.") diff --git a/tests/test_cli/test_cli_groups/test_ssh_functions.py b/tests/test_cli/test_cli_groups/test_ssh_functions.py index 2daa4289..c1dcd261 100644 --- a/tests/test_cli/test_cli_groups/test_ssh_functions.py +++ b/tests/test_cli/test_cli_groups/test_ssh_functions.py @@ -8,6 +8,7 @@ generate_ssh_key_pair, add_ssh_key ) + class TestSSHFunctions(unittest.TestCase): """ Tests for the SSH functions """ @@ -58,9 +59,11 @@ def test_generate_ssh_key_pair(self, mock_generate, mock_path_join): mock_generate.return_value.get_base64.return_value = "ABCDE12345" mock_path_join.return_value = "/path/to/private_key" - with patch("builtins.open", mock_open()) as mock_file, \ - patch("runpod.cli.groups.ssh.functions.os.chmod") as mock_chmod, \ - patch("runpod.cli.groups.ssh.functions.add_ssh_key") as mock_add_key: + with patch("os.mkdir") as mock_mkdir, \ + patch("builtins.open", mock_open()) as mock_file, \ + patch("runpod.cli.groups.ssh.functions.os.chmod") as mock_chmod, \ + patch("runpod.cli.groups.ssh.functions.add_ssh_key") as mock_add_key: + mock_mkdir.return_value = None mock_file.return_value.write.return_value = None private_key, public_key = generate_ssh_key_pair("test_key") self.assertEqual(public_key, "ssh-rsa ABCDE12345 test_key")