diff --git a/runpod/cli/groups/pod/commands.py b/runpod/cli/groups/pod/commands.py index 34b2ca1a..680f3f66 100644 --- a/runpod/cli/groups/pod/commands.py +++ b/runpod/cli/groups/pod/commands.py @@ -2,6 +2,9 @@ RunPod | CLI | Pod | Commands """ +import os +import tempfile +import uuid import click from prettytable import PrettyTable @@ -71,3 +74,173 @@ def connect_to_pod(pod_id): click.echo(f"Connecting to pod {pod_id}...") ssh = ssh_cmd.SSHConnection(pod_id) ssh.launch_terminal() + + +@pod_cli.command("sync") +@click.argument("source_pod_id") +@click.argument("dest_pod_id") +@click.argument("source_workspace", default="/workspace") +@click.argument("dest_workspace", default="/workspace") +def sync_pods(source_pod_id, dest_pod_id, source_workspace, dest_workspace): + """ + Sync data between two pods via SSH. + + Transfers files from source_pod_id:source_workspace to dest_pod_id:dest_workspace. + The workspace will be zipped and transferred to avoid file name conflicts. + + 📋 PREREQUISITES: + + 1. SSH Key Setup: + • You must have an SSH key configured in your RunPod account + • If you don't have one, create it with: runpod ssh add-key + • List your keys with: runpod ssh list-keys + + 2. Pod Configuration: + • Both pods must have SSH access enabled + • For running pods using official RunPod templates, you may need to add + your public key to the PUBLIC_KEY environment variable and restart the pod + + ⚠️ IMPORTANT NOTES: + + • If a pod was started before adding your SSH key, you'll need to: + 1. Stop the pod + 2. Add PUBLIC_KEY environment variable with your public key + 3. Restart the pod + + • The sync creates a unique folder (sync_XXXXXXXX) in the destination to avoid + file conflicts + + 📖 EXAMPLES: + + Basic sync (uses /workspace as default): + runpod pod sync pod1 pod2 + + Custom paths: + runpod pod sync pod1 pod2 /workspace/data /workspace/backup + + Different directories: + runpod pod sync pod1 pod2 /home/user/files /workspace/imported + """ + + # Check if user has SSH keys configured + try: + from ...groups.ssh.functions import get_user_pub_keys + user_keys = get_user_pub_keys() + if not user_keys: + click.echo("❌ No SSH keys found in your RunPod account!") + click.echo("") + click.echo("🔑 To create an SSH key, run:") + click.echo(" runpod ssh add-key") + click.echo("") + click.echo("📖 For more help, see:") + click.echo(" runpod ssh add-key --help") + return + else: + click.echo(f"✅ Found {len(user_keys)} SSH key(s) in your account") + except Exception as e: + click.echo(f"⚠️ Warning: Could not verify SSH keys: {str(e)}") + click.echo("Continuing with sync attempt...") + + click.echo(f"🔄 Syncing from {source_pod_id}:{source_workspace} to {dest_pod_id}:{dest_workspace}") + + # Generate unique folder name to avoid conflicts + transfer_id = str(uuid.uuid4())[:8] + temp_zip_name = f"sync_{transfer_id}.tar.gz" + dest_folder = f"sync_{transfer_id}" + + try: + # Connect to source pod + click.echo(f"📡 Connecting to source pod {source_pod_id}...") + with ssh_cmd.SSHConnection(source_pod_id) as source_ssh: + + # Count files in source directory + click.echo(f"📊 Counting files in {source_workspace}...") + _, stdout, _ = source_ssh.ssh.exec_command(f"find {source_workspace} -type f | wc -l") + file_count = stdout.read().decode().strip() + click.echo(f"📁 Found {file_count} files in source workspace") + + # Check if source directory exists + _, stdout, stderr = source_ssh.ssh.exec_command(f"test -d {source_workspace} && echo 'exists' || echo 'not_found'") + result = stdout.read().decode().strip() + if result != 'exists': + click.echo(f"❌ Error: Source workspace {source_workspace} does not exist on pod {source_pod_id}") + return + + # Create tar.gz archive of the workspace + click.echo(f"📦 Creating archive of {source_workspace}...") + archive_path = f"/tmp/{temp_zip_name}" + tar_command = f"cd {os.path.dirname(source_workspace)} && tar -czf {archive_path} {os.path.basename(source_workspace)}" + source_ssh.run_commands([tar_command]) + + # Check if archive was created successfully + _, stdout, _ = source_ssh.ssh.exec_command(f"test -f {archive_path} && echo 'created' || echo 'failed'") + archive_result = stdout.read().decode().strip() + if archive_result != 'created': + click.echo(f"❌ Error: Failed to create archive on source pod") + return + + # Get archive size for progress indication + _, stdout, _ = source_ssh.ssh.exec_command(f"du -h {archive_path} | cut -f1") + archive_size = stdout.read().decode().strip() + click.echo(f"✅ Archive created successfully ({archive_size})") + + # Download archive to local temp file + click.echo("⬇️ Downloading archive to local machine...") + with tempfile.NamedTemporaryFile(delete=False, suffix=".tar.gz") as temp_file: + local_temp_path = temp_file.name + source_ssh.get_file(archive_path, local_temp_path) + + # Clean up archive on source pod + source_ssh.run_commands([f"rm -f {archive_path}"]) + + # Connect to destination pod + click.echo(f"📡 Connecting to destination pod {dest_pod_id}...") + with ssh_cmd.SSHConnection(dest_pod_id) as dest_ssh: + + # Check if destination directory exists, create if not + click.echo(f"📂 Preparing destination workspace {dest_workspace}...") + dest_ssh.run_commands([f"mkdir -p {dest_workspace}"]) + + # Upload archive to destination pod + click.echo("⬆️ Uploading archive to destination pod...") + dest_archive_path = f"/tmp/{temp_zip_name}" + dest_ssh.put_file(local_temp_path, dest_archive_path) + + # Extract archive in destination workspace + click.echo(f"📦 Extracting archive to {dest_workspace}/{dest_folder}...") + extract_command = f"cd {dest_workspace} && mkdir -p {dest_folder} && cd {dest_folder} && tar -xzf {dest_archive_path} --strip-components=1" + dest_ssh.run_commands([extract_command]) + + # Verify extraction and count files + _, stdout, _ = dest_ssh.ssh.exec_command(f"find {dest_workspace}/{dest_folder} -type f | wc -l") + dest_file_count = stdout.read().decode().strip() + click.echo(f"📁 Extracted {dest_file_count} files to destination") + + # Clean up archive on destination pod + dest_ssh.run_commands([f"rm -f {dest_archive_path}"]) + + # Show final destination path + click.echo("") + click.echo("🎉 Sync completed successfully!") + click.echo(f"📊 Files transferred: {file_count}") + click.echo(f"📍 Destination location: {dest_pod_id}:{dest_workspace}/{dest_folder}") + click.echo("") + click.echo("💡 To access the synced files:") + click.echo(f" runpod ssh {dest_pod_id}") + click.echo(f" cd {dest_workspace}/{dest_folder}") + + except Exception as e: + click.echo(f"❌ Error during sync: {str(e)}") + click.echo("") + click.echo("🔧 Troubleshooting tips:") + click.echo("• Ensure both pods have SSH access enabled") + click.echo("• Check that your SSH key is added to your RunPod account: runpod ssh list-keys") + click.echo("• For running pods, you may need to add PUBLIC_KEY env var and restart") + click.echo("• Verify the source and destination paths exist") + finally: + # Clean up local temp file + try: + if 'local_temp_path' in locals(): + os.unlink(local_temp_path) + except: + pass diff --git a/runpod/serverless/utils/rp_debugger.py b/runpod/serverless/utils/rp_debugger.py index e99f7cd4..6b37f196 100644 --- a/runpod/serverless/utils/rp_debugger.py +++ b/runpod/serverless/utils/rp_debugger.py @@ -5,6 +5,7 @@ """ import datetime +from datetime import timezone import platform import time @@ -86,7 +87,7 @@ def start(self, name): index = self.name_lookup[name] self.checkpoints[index]["start"] = time.perf_counter() self.checkpoints[index]["start_utc"] = ( - datetime.datetime.utcnow().isoformat() + "Z" + datetime.datetime.now(timezone.utc).isoformat() + "Z" ) def stop(self, name): @@ -103,7 +104,7 @@ def stop(self, name): self.checkpoints[index]["end"] = time.perf_counter() self.checkpoints[index]["stop_utc"] = ( - datetime.datetime.utcnow().isoformat() + "Z" + datetime.datetime.now(timezone.utc).isoformat() + "Z" ) def get_checkpoints(self): diff --git a/tests/test_cli/test_cli_groups/test_pod_commands.py b/tests/test_cli/test_cli_groups/test_pod_commands.py index 92cb4076..a50e4561 100644 --- a/tests/test_cli/test_cli_groups/test_pod_commands.py +++ b/tests/test_cli/test_cli_groups/test_pod_commands.py @@ -1,7 +1,7 @@ """ Test CLI pod commands """ import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, mock_open from click.testing import CliRunner from prettytable import PrettyTable @@ -96,3 +96,225 @@ def test_connect_to_pod(self, mock_ssh_connection, mock_echo): mock_echo.assert_called_once_with(f"Connecting to pod {pod_id}...") mock_ssh_connection.assert_called_once_with(pod_id) mock_ssh.launch_terminal.assert_called_once_with() + + @patch("runpod.cli.groups.pod.commands.os.unlink") + @patch("runpod.cli.groups.pod.commands.tempfile.NamedTemporaryFile") + @patch("runpod.cli.groups.pod.commands.uuid.uuid4") + @patch("runpod.cli.groups.pod.commands.click.echo") + @patch("runpod.cli.groups.pod.commands.ssh_cmd.SSHConnection") + def test_sync_pods_success(self, mock_ssh_connection, mock_echo, mock_uuid, mock_temp_file, mock_unlink): + """ + Test sync_pods function - successful sync + """ + # Setup mocks + mock_uuid.return_value = MagicMock() + mock_uuid.return_value.__str__ = MagicMock(return_value="12345678-1234-1234-1234-123456789012") + + mock_temp_file.return_value.__enter__.return_value.name = "/tmp/test_archive.tar.gz" + + # Mock SSH connections + mock_source_ssh = MagicMock() + mock_dest_ssh = MagicMock() + + # Mock SSH exec_command responses + mock_source_ssh.ssh.exec_command.side_effect = [ + (None, MagicMock(read=lambda: b"42"), None), # file count + (None, MagicMock(read=lambda: b"exists"), None), # directory exists check + (None, MagicMock(read=lambda: b"created"), None), # archive created check + (None, MagicMock(read=lambda: b"1.5M"), None), # archive size + ] + + mock_dest_ssh.ssh.exec_command.return_value = (None, MagicMock(read=lambda: b"42"), None) # dest file count + + # Configure SSH connection context manager + def ssh_side_effect(pod_id): + if pod_id == "source_pod": + mock_source_ssh.__enter__ = MagicMock(return_value=mock_source_ssh) + mock_source_ssh.__exit__ = MagicMock(return_value=None) + return mock_source_ssh + elif pod_id == "dest_pod": + mock_dest_ssh.__enter__ = MagicMock(return_value=mock_dest_ssh) + mock_dest_ssh.__exit__ = MagicMock(return_value=None) + return mock_dest_ssh + + mock_ssh_connection.side_effect = ssh_side_effect + + # Mock SSH key validation + with patch("runpod.cli.groups.ssh.functions.get_user_pub_keys") as mock_get_keys: + mock_get_keys.return_value = [{"name": "test-key", "type": "ssh-rsa", "fingerprint": "SHA256:test"}] + + runner = CliRunner() + result = runner.invoke(runpod_cli, ["pod", "sync", "source_pod", "dest_pod", "/workspace", "/workspace"]) + + assert result.exit_code == 0, result.exception + + # Verify SSH connections were created + assert mock_ssh_connection.call_count == 2 + mock_ssh_connection.assert_any_call("source_pod") + mock_ssh_connection.assert_any_call("dest_pod") + + # Verify file operations + mock_source_ssh.get_file.assert_called_once() + mock_dest_ssh.put_file.assert_called_once() + + # Verify commands were run + mock_source_ssh.run_commands.assert_called() + mock_dest_ssh.run_commands.assert_called() + + @patch("runpod.cli.groups.pod.commands.uuid.uuid4") + @patch("runpod.cli.groups.pod.commands.click.echo") + @patch("runpod.cli.groups.pod.commands.ssh_cmd.SSHConnection") + def test_sync_pods_no_ssh_keys(self, mock_ssh_connection, mock_echo, mock_uuid): + """ + Test sync_pods function - no SSH keys configured + """ + # Setup mocks + mock_uuid.return_value = MagicMock() + mock_uuid.return_value.__str__ = MagicMock(return_value="12345678-1234-1234-1234-123456789012") + + # Mock SSH key validation - no keys found + with patch("runpod.cli.groups.ssh.functions.get_user_pub_keys") as mock_get_keys: + mock_get_keys.return_value = [] # No SSH keys + + runner = CliRunner() + result = runner.invoke(runpod_cli, ["pod", "sync", "source_pod", "dest_pod"]) + + assert result.exit_code == 0, result.exception + + # Verify error message was shown + mock_echo.assert_any_call("❌ No SSH keys found in your RunPod account!") + mock_echo.assert_any_call("🔑 To create an SSH key, run:") + mock_echo.assert_any_call(" runpod ssh add-key") + + @patch("runpod.cli.groups.pod.commands.uuid.uuid4") + @patch("runpod.cli.groups.pod.commands.click.echo") + @patch("runpod.cli.groups.pod.commands.ssh_cmd.SSHConnection") + def test_sync_pods_source_not_found(self, mock_ssh_connection, mock_echo, mock_uuid): + """ + Test sync_pods function - source directory not found + """ + # Setup mocks + mock_uuid.return_value = MagicMock() + mock_uuid.return_value.__str__ = MagicMock(return_value="12345678-1234-1234-1234-123456789012") + + mock_source_ssh = MagicMock() + mock_source_ssh.__enter__ = MagicMock(return_value=mock_source_ssh) + mock_source_ssh.__exit__ = MagicMock(return_value=None) + + # Mock SSH exec_command responses - directory doesn't exist + mock_source_ssh.ssh.exec_command.side_effect = [ + (None, MagicMock(read=lambda: b"0"), None), # file count + (None, MagicMock(read=lambda: b"not_found"), None), # directory exists check + ] + + mock_ssh_connection.return_value = mock_source_ssh + + # Mock SSH key validation + with patch("runpod.cli.groups.ssh.functions.get_user_pub_keys") as mock_get_keys: + mock_get_keys.return_value = [{"name": "test-key", "type": "ssh-rsa", "fingerprint": "SHA256:test"}] + + runner = CliRunner() + result = runner.invoke(runpod_cli, ["pod", "sync", "source_pod", "dest_pod", "/nonexistent", "/workspace"]) + + assert result.exit_code == 0, result.exception + + # Verify error message was shown + mock_echo.assert_any_call("❌ Error: Source workspace /nonexistent does not exist on pod source_pod") + + @patch("runpod.cli.groups.pod.commands.uuid.uuid4") + @patch("runpod.cli.groups.pod.commands.click.echo") + @patch("runpod.cli.groups.pod.commands.ssh_cmd.SSHConnection") + def test_sync_pods_archive_creation_failed(self, mock_ssh_connection, mock_echo, mock_uuid): + """ + Test sync_pods function - archive creation failed + """ + # Setup mocks + mock_uuid.return_value = MagicMock() + mock_uuid.return_value.__str__ = MagicMock(return_value="12345678-1234-1234-1234-123456789012") + + mock_source_ssh = MagicMock() + mock_source_ssh.__enter__ = MagicMock(return_value=mock_source_ssh) + mock_source_ssh.__exit__ = MagicMock(return_value=None) + + # Mock SSH exec_command responses - archive creation fails + mock_source_ssh.ssh.exec_command.side_effect = [ + (None, MagicMock(read=lambda: b"42"), None), # file count + (None, MagicMock(read=lambda: b"exists"), None), # directory exists check + (None, MagicMock(read=lambda: b"failed"), None), # archive created check fails + ] + + mock_ssh_connection.return_value = mock_source_ssh + + # Mock SSH key validation + with patch("runpod.cli.groups.ssh.functions.get_user_pub_keys") as mock_get_keys: + mock_get_keys.return_value = [{"name": "test-key", "type": "ssh-rsa", "fingerprint": "SHA256:test"}] + + runner = CliRunner() + result = runner.invoke(runpod_cli, ["pod", "sync", "source_pod", "dest_pod", "/workspace", "/workspace"]) + + assert result.exit_code == 0, result.exception + + # Verify error message was shown + mock_echo.assert_any_call("❌ Error: Failed to create archive on source pod") + + @patch("runpod.cli.groups.pod.commands.uuid.uuid4") + @patch("runpod.cli.groups.pod.commands.click.echo") + @patch("runpod.cli.groups.pod.commands.ssh_cmd.SSHConnection") + def test_sync_pods_ssh_exception(self, mock_ssh_connection, mock_echo, mock_uuid): + """ + Test sync_pods function - SSH connection exception + """ + # Setup mocks + mock_uuid.return_value = MagicMock() + mock_uuid.return_value.__str__ = MagicMock(return_value="12345678-1234-1234-1234-123456789012") + + # Mock SSH connection to raise exception + mock_ssh_connection.side_effect = Exception("SSH connection failed") + + # Mock SSH key validation + with patch("runpod.cli.groups.ssh.functions.get_user_pub_keys") as mock_get_keys: + mock_get_keys.return_value = [{"name": "test-key", "type": "ssh-rsa", "fingerprint": "SHA256:test"}] + + runner = CliRunner() + result = runner.invoke(runpod_cli, ["pod", "sync", "source_pod", "dest_pod", "/workspace", "/workspace"]) + + assert result.exit_code == 0, result.exception + + # Verify error message was shown + mock_echo.assert_any_call("❌ Error during sync: SSH connection failed") + + @patch("runpod.cli.groups.pod.commands.uuid.uuid4") + @patch("runpod.cli.groups.pod.commands.click.echo") + @patch("runpod.cli.groups.pod.commands.ssh_cmd.SSHConnection") + def test_sync_pods_default_workspace(self, mock_ssh_connection, mock_echo, mock_uuid): + """ + Test sync_pods function - using default workspace paths + """ + # Setup mocks + mock_uuid.return_value = MagicMock() + mock_uuid.return_value.__str__ = MagicMock(return_value="12345678-1234-1234-1234-123456789012") + + mock_source_ssh = MagicMock() + mock_source_ssh.__enter__ = MagicMock(return_value=mock_source_ssh) + mock_source_ssh.__exit__ = MagicMock(return_value=None) + + # Mock SSH exec_command responses + mock_source_ssh.ssh.exec_command.side_effect = [ + (None, MagicMock(read=lambda: b"10"), None), # file count + (None, MagicMock(read=lambda: b"exists"), None), # directory exists check + ] + + mock_ssh_connection.return_value = mock_source_ssh + + # Mock SSH key validation + with patch("runpod.cli.groups.ssh.functions.get_user_pub_keys") as mock_get_keys: + mock_get_keys.return_value = [{"name": "test-key", "type": "ssh-rsa", "fingerprint": "SHA256:test"}] + + runner = CliRunner() + # Test with only pod IDs (should use /workspace as default) + result = runner.invoke(runpod_cli, ["pod", "sync", "source_pod", "dest_pod"]) + + assert result.exit_code == 0, result.exception + + # Verify the default workspace path is used + mock_echo.assert_any_call("🔄 Syncing from source_pod:/workspace to dest_pod:/workspace") diff --git a/tests/test_serverless/test_utils/test_download.py b/tests/test_serverless/test_utils/test_download.py index 10c5bcd0..f0ab6d98 100644 --- a/tests/test_serverless/test_utils/test_download.py +++ b/tests/test_serverless/test_utils/test_download.py @@ -60,7 +60,9 @@ def __enter__(self): def __exit__(self, *args): pass - if args[0] in URL_LIST: + url = args[0] + # Check if the URL matches any of the URLs in URL_LIST + if any(url.startswith(base_url) for base_url in URL_LIST): return MockResponse(b"nothing", 200, headers) return MockResponse(None, 404)