From 185326b7fcafaf49301eb867bd2ba3d19a502a46 Mon Sep 17 00:00:00 2001 From: Justin Merrell Date: Thu, 9 Nov 2023 15:02:19 -0500 Subject: [PATCH 1/7] fix: move file creation --- runpod/__init__.py | 4 ++++ runpod/cli/groups/ssh/functions.py | 10 +++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/runpod/__init__.py b/runpod/__init__.py index c084a037..a16b740e 100644 --- a/runpod/__init__.py +++ b/runpod/__init__.py @@ -20,6 +20,10 @@ # ------------------------------- Config Paths ------------------------------- # SSH_KEY_PATH = os.path.expanduser('~/.runpod/ssh') +try: + os.makedirs(os.path.join(SSH_KEY_PATH), exist_ok=True) +except OSError: + pass profile = "default" # pylint: disable=invalid-name diff --git a/runpod/cli/groups/ssh/functions.py b/runpod/cli/groups/ssh/functions.py index f026d77a..ed02322a 100644 --- a/runpod/cli/groups/ssh/functions.py +++ b/runpod/cli/groups/ssh/functions.py @@ -7,10 +7,9 @@ import hashlib import paramiko +from runpod import SSH_KEY_PATH from runpod.api.ctl_commands import get_user, update_user_settings -SSH_FILES = os.path.expanduser('~/.runpod/ssh') -os.makedirs(os.path.join(SSH_FILES), exist_ok=True) def get_ssh_key_fingerprint(public_key): ''' @@ -53,6 +52,7 @@ def get_user_pub_keys(): return key_list + def generate_ssh_key_pair(filename): """ Generate an RSA SSH key pair and save it to disk. @@ -63,13 +63,13 @@ def generate_ssh_key_pair(filename): """ # Generate private key private_key = paramiko.RSAKey.generate(bits=2048) - private_key.write_private_key_file(os.path.join(SSH_FILES, filename)) + private_key.write_private_key_file(os.path.join(SSH_KEY_PATH, filename)) # Set permissions - os.chmod(os.path.join(SSH_FILES, filename), 0o600) + os.chmod(os.path.join(SSH_KEY_PATH, filename), 0o600) # Generate public key - with open(f"{SSH_FILES}/{filename}.pub", "w", encoding="UTF-8") as public_file: + with open(f"{SSH_KEY_PATH}/{filename}.pub", "w", encoding="UTF-8") as public_file: public_key = f"{private_key.get_name()} {private_key.get_base64()} {filename}" public_file.write(public_key) From 574354c4917f7282db6532c238d702ae0ec71dd9 Mon Sep 17 00:00:00 2001 From: Justin Merrell Date: Thu, 9 Nov 2023 15:36:45 -0500 Subject: [PATCH 2/7] fix: circular import --- runpod/__init__.py | 8 -------- runpod/cli/__init__.py | 10 ++++++++++ runpod/cli/groups/ssh/functions.py | 2 +- runpod/cli/utils/rp_userspace.py | 6 +++--- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/runpod/__init__.py b/runpod/__init__.py index a16b740e..e5238eeb 100644 --- a/runpod/__init__.py +++ b/runpod/__init__.py @@ -18,14 +18,6 @@ from .cli.groups.config.functions import set_credentials, check_credentials, get_credentials -# ------------------------------- Config Paths ------------------------------- # -SSH_KEY_PATH = os.path.expanduser('~/.runpod/ssh') -try: - os.makedirs(os.path.join(SSH_KEY_PATH), exist_ok=True) -except OSError: - pass - - profile = "default" # pylint: disable=invalid-name _credentials = get_credentials(profile) diff --git a/runpod/cli/__init__.py b/runpod/cli/__init__.py index 588ce35e..6e1d643d 100644 --- a/runpod/cli/__init__.py +++ b/runpod/cli/__init__.py @@ -1,11 +1,21 @@ ''' Allows the CLI to be imported as a module. ''' +import os import threading from .groups import config, ssh STOP_EVENT = threading.Event() + +# ------------------------------- Config Paths ------------------------------- # +SSH_KEY_PATH = os.path.expanduser('~/.runpod/ssh') +try: + os.makedirs(os.path.join(SSH_KEY_PATH), exist_ok=True) +except OSError: + pass + + # --------------------------- runpod.toml Defaults --------------------------- # BASE_DOCKER_IMAGE = 'runpod/base:0.4.0-cuda{cuda_version}' GPU_TYPES = [ diff --git a/runpod/cli/groups/ssh/functions.py b/runpod/cli/groups/ssh/functions.py index ed02322a..e8027514 100644 --- a/runpod/cli/groups/ssh/functions.py +++ b/runpod/cli/groups/ssh/functions.py @@ -7,7 +7,7 @@ import hashlib import paramiko -from runpod import SSH_KEY_PATH +from runpod.cli import SSH_KEY_PATH from runpod.api.ctl_commands import get_user, update_user_settings diff --git a/runpod/cli/utils/rp_userspace.py b/runpod/cli/utils/rp_userspace.py index da415abc..9535d10b 100644 --- a/runpod/cli/utils/rp_userspace.py +++ b/runpod/cli/utils/rp_userspace.py @@ -3,7 +3,7 @@ ''' import os import paramiko -from runpod import SSH_KEY_PATH +from runpod.cli import SSH_KEY_PATH def find_ssh_key_file(pod_ip, pod_port): @@ -17,7 +17,7 @@ def find_ssh_key_file(pod_ip, pod_port): ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - for file in os.listdir(SSH_KEY_PATH): + for file in os.listdir(SSH_KEY_PATH): file_path = os.path.join(SSH_KEY_PATH, file) if not os.path.isfile(file_path) or file.endswith('.pub'): @@ -28,7 +28,7 @@ def find_ssh_key_file(pod_ip, pod_port): ssh.close() print(f"Connected to pod {pod_ip}:{pod_port} using key {file}") return file_path - except Exception as err: # pylint: disable=broad-except + except Exception as err: # pylint: disable=broad-except print(f"An error occurred with key {file}: {err}") print("Failed to connect using all available keys.") From 370256ef4b7eafb2e979691a320d3b3c926e16e8 Mon Sep 17 00:00:00 2001 From: Justin Merrell Date: Thu, 9 Nov 2023 15:42:22 -0500 Subject: [PATCH 3/7] fix: bandaid --- runpod/__init__.py | 4 ++++ runpod/cli/groups/ssh/functions.py | 14 ++++++++++---- runpod/cli/utils/rp_userspace.py | 2 +- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/runpod/__init__.py b/runpod/__init__.py index e5238eeb..c084a037 100644 --- a/runpod/__init__.py +++ b/runpod/__init__.py @@ -18,6 +18,10 @@ from .cli.groups.config.functions import set_credentials, check_credentials, get_credentials +# ------------------------------- Config Paths ------------------------------- # +SSH_KEY_PATH = os.path.expanduser('~/.runpod/ssh') + + profile = "default" # pylint: disable=invalid-name _credentials = get_credentials(profile) diff --git a/runpod/cli/groups/ssh/functions.py b/runpod/cli/groups/ssh/functions.py index e8027514..dc4481e2 100644 --- a/runpod/cli/groups/ssh/functions.py +++ b/runpod/cli/groups/ssh/functions.py @@ -7,9 +7,15 @@ import hashlib import paramiko -from runpod.cli import SSH_KEY_PATH from runpod.api.ctl_commands import get_user, update_user_settings +SSH_FILES = os.path.expanduser('~/.runpod/ssh') + +try: + os.makedirs(os.path.join(SSH_FILES), exist_ok=True) +except OSError: + pass + def get_ssh_key_fingerprint(public_key): ''' @@ -63,13 +69,13 @@ def generate_ssh_key_pair(filename): """ # Generate private key private_key = paramiko.RSAKey.generate(bits=2048) - private_key.write_private_key_file(os.path.join(SSH_KEY_PATH, filename)) + private_key.write_private_key_file(os.path.join(SSH_FILES, filename)) # Set permissions - os.chmod(os.path.join(SSH_KEY_PATH, filename), 0o600) + os.chmod(os.path.join(SSH_FILES, filename), 0o600) # Generate public key - with open(f"{SSH_KEY_PATH}/{filename}.pub", "w", encoding="UTF-8") as public_file: + with open(f"{SSH_FILES}/{filename}.pub", "w", encoding="UTF-8") as public_file: public_key = f"{private_key.get_name()} {private_key.get_base64()} {filename}" public_file.write(public_key) diff --git a/runpod/cli/utils/rp_userspace.py b/runpod/cli/utils/rp_userspace.py index 9535d10b..5a1308d6 100644 --- a/runpod/cli/utils/rp_userspace.py +++ b/runpod/cli/utils/rp_userspace.py @@ -3,7 +3,7 @@ ''' import os import paramiko -from runpod.cli import SSH_KEY_PATH +from runpod import SSH_KEY_PATH def find_ssh_key_file(pod_ip, pod_port): From e283a4291385f67e928152e1106e1a460b5b68ee Mon Sep 17 00:00:00 2001 From: Justin Merrell Date: Thu, 9 Nov 2023 15:49:33 -0500 Subject: [PATCH 4/7] Update __init__.py --- runpod/cli/__init__.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/runpod/cli/__init__.py b/runpod/cli/__init__.py index 6e1d643d..b58a3ff2 100644 --- a/runpod/cli/__init__.py +++ b/runpod/cli/__init__.py @@ -1,6 +1,5 @@ ''' Allows the CLI to be imported as a module. ''' -import os import threading from .groups import config, ssh @@ -8,14 +7,6 @@ STOP_EVENT = threading.Event() -# ------------------------------- Config Paths ------------------------------- # -SSH_KEY_PATH = os.path.expanduser('~/.runpod/ssh') -try: - os.makedirs(os.path.join(SSH_KEY_PATH), exist_ok=True) -except OSError: - pass - - # --------------------------- runpod.toml Defaults --------------------------- # BASE_DOCKER_IMAGE = 'runpod/base:0.4.0-cuda{cuda_version}' GPU_TYPES = [ From 71340e61ac1b014a243a2fdde5bdee729b7d20d7 Mon Sep 17 00:00:00 2001 From: Justin Merrell Date: Thu, 9 Nov 2023 16:05:52 -0500 Subject: [PATCH 5/7] Update functions.py --- runpod/cli/groups/ssh/functions.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/runpod/cli/groups/ssh/functions.py b/runpod/cli/groups/ssh/functions.py index dc4481e2..380ea812 100644 --- a/runpod/cli/groups/ssh/functions.py +++ b/runpod/cli/groups/ssh/functions.py @@ -11,11 +11,6 @@ SSH_FILES = os.path.expanduser('~/.runpod/ssh') -try: - os.makedirs(os.path.join(SSH_FILES), exist_ok=True) -except OSError: - pass - def get_ssh_key_fingerprint(public_key): ''' @@ -67,6 +62,9 @@ def generate_ssh_key_pair(filename): - filename (str): The base filename to save the key pair. The public key will have '.pub' appended to it. """ + # Create the SSH directory if it doesn't exist + os.makedirs(os.path.join(SSH_FILES), exist_ok=True) + # Generate private key private_key = paramiko.RSAKey.generate(bits=2048) private_key.write_private_key_file(os.path.join(SSH_FILES, filename)) From 331a0573816e19ea9ea8d9200e4dc545bd70174d Mon Sep 17 00:00:00 2001 From: Justin Merrell Date: Thu, 9 Nov 2023 18:41:27 -0500 Subject: [PATCH 6/7] fix: permission error --- runpod/cli/groups/ssh/functions.py | 1 - tests/test_cli/test_cli_groups/test_ssh_functions.py | 10 +++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/runpod/cli/groups/ssh/functions.py b/runpod/cli/groups/ssh/functions.py index 380ea812..c33c2154 100644 --- a/runpod/cli/groups/ssh/functions.py +++ b/runpod/cli/groups/ssh/functions.py @@ -62,7 +62,6 @@ def generate_ssh_key_pair(filename): - filename (str): The base filename to save the key pair. The public key will have '.pub' appended to it. """ - # Create the SSH directory if it doesn't exist os.makedirs(os.path.join(SSH_FILES), exist_ok=True) # Generate private key diff --git a/tests/test_cli/test_cli_groups/test_ssh_functions.py b/tests/test_cli/test_cli_groups/test_ssh_functions.py index 2daa4289..a60f4303 100644 --- a/tests/test_cli/test_cli_groups/test_ssh_functions.py +++ b/tests/test_cli/test_cli_groups/test_ssh_functions.py @@ -2,12 +2,14 @@ import base64 import unittest +from unittest import mock from unittest.mock import patch, mock_open from runpod.cli.groups.ssh.functions import ( get_ssh_key_fingerprint, get_user_pub_keys, generate_ssh_key_pair, add_ssh_key ) + class TestSSHFunctions(unittest.TestCase): """ Tests for the SSH functions """ @@ -58,9 +60,11 @@ def test_generate_ssh_key_pair(self, mock_generate, mock_path_join): mock_generate.return_value.get_base64.return_value = "ABCDE12345" mock_path_join.return_value = "/path/to/private_key" - with patch("builtins.open", mock_open()) as mock_file, \ - patch("runpod.cli.groups.ssh.functions.os.chmod") as mock_chmod, \ - patch("runpod.cli.groups.ssh.functions.add_ssh_key") as mock_add_key: + with patch("os.mkdir") as mock_mkdir, \ + patch("builtins.open", mock_open()) as mock_file, \ + patch("runpod.cli.groups.ssh.functions.os.chmod") as mock_chmod, \ + patch("runpod.cli.groups.ssh.functions.add_ssh_key") as mock_add_key: + mock_mkdir.return_value = None mock_file.return_value.write.return_value = None private_key, public_key = generate_ssh_key_pair("test_key") self.assertEqual(public_key, "ssh-rsa ABCDE12345 test_key") From ec3d5ebc860839368b118d65e33d742369be3078 Mon Sep 17 00:00:00 2001 From: Justin Merrell Date: Thu, 9 Nov 2023 18:44:21 -0500 Subject: [PATCH 7/7] Update test_ssh_functions.py --- tests/test_cli/test_cli_groups/test_ssh_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_cli/test_cli_groups/test_ssh_functions.py b/tests/test_cli/test_cli_groups/test_ssh_functions.py index a60f4303..c1dcd261 100644 --- a/tests/test_cli/test_cli_groups/test_ssh_functions.py +++ b/tests/test_cli/test_cli_groups/test_ssh_functions.py @@ -2,7 +2,6 @@ import base64 import unittest -from unittest import mock from unittest.mock import patch, mock_open from runpod.cli.groups.ssh.functions import ( get_ssh_key_fingerprint, get_user_pub_keys,