diff --git a/src/sagemaker/hyperpod/common/cli_decorators.py b/src/sagemaker/hyperpod/common/cli_decorators.py index 50642684..6768aa7c 100644 --- a/src/sagemaker/hyperpod/common/cli_decorators.py +++ b/src/sagemaker/hyperpod/common/cli_decorators.py @@ -7,7 +7,11 @@ import click import functools import logging +import subprocess +import os from kubernetes.client.exceptions import ApiException +from kubernetes import config +from kubernetes.config.config_exception import ConfigException logger = logging.getLogger(__name__) @@ -761,6 +765,210 @@ def _check_resources_exist(raw_resource_type: str, namespace: str) -> bool: logger.debug(f"Failed to check resource existence for {raw_resource_type}: {e}") return None + +def _check_kubernetes_connectivity() -> tuple[bool, str]: + """ + Check if we can connect to Kubernetes cluster. + Returns (is_connected, error_message). + """ + try: + # Try to load kubeconfig and make a simple API call + config.load_kube_config() + from kubernetes import client + + # Try to get cluster version - this is a lightweight call that requires authentication + version_api = client.VersionApi() + version_api.get_code() + + return True, "" + except ConfigException as e: + if "No configuration found" in str(e): + return False, "no_config" + elif "Invalid kube-config" in str(e): + return False, "invalid_config" + else: + return False, f"config_error: {str(e)}" + except ApiException as e: + if e.status == 401: + return False, "unauthorized" + elif e.status == 403: + return False, "forbidden" + else: + return False, f"api_error: {e.status} {e.reason}" + except Exception as e: + error_str = str(e).lower() + if "unauthorized" in error_str or "401" in error_str: + return False, "unauthorized" + elif "forbidden" in error_str or "403" in error_str: + return False, "forbidden" + elif "connection" in error_str or "timeout" in error_str: + return False, "connection_error" + else: + return False, f"unknown_error: {str(e)}" + +def _check_aws_credentials() -> bool: + """ + Check if AWS credentials are available and valid. + """ + try: + import boto3 + from botocore.exceptions import NoCredentialsError, PartialCredentialsError + + # Try to get caller identity + sts = boto3.client('sts') + sts.get_caller_identity() + return True + except (NoCredentialsError, PartialCredentialsError): + return False + except Exception as e: + logger.debug(f"AWS credentials check failed: {e}") + return False + +def _get_current_kubernetes_context() -> str: + """ + Get current Kubernetes context name. + """ + try: + contexts, active_context = config.list_kube_config_contexts() + if active_context: + return active_context.get('name', 'unknown') + return 'none' + except Exception as e: + logger.debug(f"Failed to get current context: {e}") + return 'unknown' + +def _generate_kubernetes_auth_error_message(error_type: str) -> str: + """ + Generate helpful error message for Kubernetes authentication issues. + """ + if error_type == "no_config": + return ( + "❌ Kubernetes configuration not found.\n" + "No kubeconfig file found. Please ensure you have:\n" + "1. A valid kubeconfig file at ~/.kube/config, or\n" + "2. Set the KUBECONFIG environment variable\n\n" + "To configure cluster access:\n" + " hyp set-cluster-context --region \n\n" + "💡 This will set up the necessary Kubernetes configuration for your HyperPod cluster." + ) + + elif error_type == "invalid_config": + return ( + "❌ Invalid Kubernetes configuration.\n" + "Your kubeconfig file appears to be corrupted or invalid.\n\n" + "To fix this:\n" + "1. Check your kubeconfig file at ~/.kube/config\n" + "2. Reconfigure cluster access:\n" + " hyp set-cluster-context --region \n\n" + "💡 This will refresh your cluster configuration with the correct settings." + ) + + elif error_type == "unauthorized": + current_context = _get_current_kubernetes_context() + aws_creds_valid = _check_aws_credentials() + + message = ( + "❌ Kubernetes authentication failed (401 Unauthorized).\n" + f"Current context: {current_context}\n\n" + "This usually means your credentials have expired or are invalid.\n\n" + ) + + if not aws_creds_valid: + message += ( + "🔍 AWS credentials issue detected:\n" + "Your AWS credentials appear to be missing or invalid.\n\n" + "To fix this:\n" + "1. Check your AWS credentials:\n" + " aws sts get-caller-identity\n" + "2. If expired, refresh your AWS credentials\n" + "3. Then reconfigure cluster access:\n" + " hyp set-cluster-context --region \n\n" + "💡 Make sure your AWS credentials have the necessary EKS permissions." + ) + else: + message += ( + "To fix this:\n" + "1. Reconfigure cluster access:\n" + " hyp set-cluster-context --region \n" + "2. Try your HyperPod command again\n\n" + "💡 This will refresh your authentication with the cluster." + ) + + return message + + elif error_type == "forbidden": + return ( + "❌ Kubernetes access denied (403 Forbidden).\n" + "You don't have permission to access this cluster or resource.\n\n" + "This could mean:\n" + "1. Your user/role lacks the necessary RBAC permissions\n" + "2. You're connected to the wrong cluster\n" + "3. The cluster's access policies have changed\n\n" + "To fix this:\n" + "1. Verify you're using the correct cluster context\n" + "2. Contact your cluster administrator for access\n" + "3. Ensure your AWS role has the necessary EKS permissions" + ) + + elif error_type == "connection_error": + return ( + "❌ Cannot connect to Kubernetes cluster.\n" + "Network connection to the cluster failed.\n\n" + "This could mean:\n" + "1. The cluster is not accessible from your network\n" + "2. The cluster endpoint URL is incorrect\n" + "3. Network connectivity issues\n\n" + "To fix this:\n" + "1. Check your network connection\n" + "2. Verify the cluster is running and accessible\n" + "3. Reconfigure cluster access:\n" + " hyp set-cluster-context --region " + ) + + else: + return ( + "❌ Kubernetes connection failed.\n" + f"Error: {error_type}\n\n" + "To troubleshoot:\n" + "1. Check your kubeconfig file at ~/.kube/config\n" + "2. Reconfigure cluster access:\n" + " hyp set-cluster-context --region \n" + "3. Try your HyperPod command again" + ) + +def _is_kubernetes_operation(func, **kwargs) -> bool: + """ + Detect if this operation requires Kubernetes connectivity. + """ + try: + # Check function name for Kubernetes-related patterns + func_name = func.__name__.lower() + k8s_patterns = ['logs', 'operator', 'pod', 'describe', 'list', 'delete', 'create'] + + if any(pattern in func_name for pattern in k8s_patterns): + return True + + # Check if wrapped function has Kubernetes patterns + if hasattr(func, '__wrapped__'): + wrapped_name = getattr(func.__wrapped__, '__name__', '').lower() + if any(pattern in wrapped_name for pattern in k8s_patterns): + return True + + # Check Click command info for Kubernetes patterns + try: + click_ctx = click.get_current_context(silent=True) + if click_ctx and hasattr(click_ctx, 'info_name'): + command_path = str(click_ctx.info_name).lower() + if any(pattern in command_path for pattern in k8s_patterns): + return True + except Exception: + pass + + except Exception as e: + logger.debug(f"Failed to detect Kubernetes operation: {e}") + + return False + def handle_cli_exceptions(): """ Template-agnostic decorator with proactive namespace validation and enhanced error handling. @@ -815,6 +1023,15 @@ def wrapper(*args, **kwargs): sys.exit(1) return + # Kubernetes connectivity check for operations that require it + if _is_kubernetes_operation(func, **kwargs): + is_connected, error_type = _check_kubernetes_connectivity() + if not is_connected: + auth_error_message = _generate_kubernetes_auth_error_message(error_type) + click.echo(auth_error_message) + sys.exit(1) + return + # Execute the command try: return func(*args, **kwargs) @@ -828,6 +1045,43 @@ def wrapper(*args, **kwargs): sys.exit(1) return + # 2.1: Enhanced Kubernetes Authentication Error Handling + # Check for 401 Unauthorized errors and provide helpful guidance + if isinstance(e, ApiException) and e.status == 401: + auth_error_message = _generate_kubernetes_auth_error_message("unauthorized") + click.echo(auth_error_message) + sys.exit(1) + return + elif "401" in str(e) and ("unauthorized" in str(e).lower() or "Unauthorized" in str(e)): + auth_error_message = _generate_kubernetes_auth_error_message("unauthorized") + click.echo(auth_error_message) + sys.exit(1) + return + + # 2.2: Enhanced Kubernetes Forbidden Error Handling + elif isinstance(e, ApiException) and e.status == 403: + auth_error_message = _generate_kubernetes_auth_error_message("forbidden") + click.echo(auth_error_message) + sys.exit(1) + return + elif "403" in str(e) and ("forbidden" in str(e).lower() or "Forbidden" in str(e)): + auth_error_message = _generate_kubernetes_auth_error_message("forbidden") + click.echo(auth_error_message) + sys.exit(1) + return + + # 2.3: Enhanced Kubernetes Configuration Error Handling + elif isinstance(e, ConfigException): + if "No configuration found" in str(e): + auth_error_message = _generate_kubernetes_auth_error_message("no_config") + elif "Invalid kube-config" in str(e): + auth_error_message = _generate_kubernetes_auth_error_message("invalid_config") + else: + auth_error_message = _generate_kubernetes_auth_error_message(f"config_error: {str(e)}") + click.echo(auth_error_message) + sys.exit(1) + return + # 3: Enhanced 404 Resource Handling with Dynamic Target Detection # Check if this is a 404 error that can benefit from enhanced handling if isinstance(e, ApiException) and e.status == 404: diff --git a/test/unit_tests/cli/test_inference.py b/test/unit_tests/cli/test_inference.py index 8cf7ccc3..7f455ca8 100644 --- a/test/unit_tests/cli/test_inference.py +++ b/test/unit_tests/cli/test_inference.py @@ -7,6 +7,18 @@ import hyperpod_jumpstart_inference_template.registry as jreg import hyperpod_custom_inference_template.registry as creg +# Mock Kubernetes connectivity for all tests in this module +@pytest.fixture(autouse=True) +def mock_kubernetes_connectivity(): + """Mock Kubernetes connectivity checks for all tests in this module.""" + with patch('sagemaker.hyperpod.common.cli_decorators._check_kubernetes_connectivity') as mock_connectivity, \ + patch('sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation') as mock_is_k8s_op: + # Always return successful connectivity + mock_connectivity.return_value = (True, "") + # Let the operation detection work normally + mock_is_k8s_op.side_effect = lambda func, **kwargs: True + yield + # Import the non-create commands that don't need special handling from sagemaker.hyperpod.cli.commands.inference import ( js_create, custom_create, custom_invoke, diff --git a/test/unit_tests/cli/test_training.py b/test/unit_tests/cli/test_training.py index 3e3c653c..dfb245f8 100644 --- a/test/unit_tests/cli/test_training.py +++ b/test/unit_tests/cli/test_training.py @@ -13,10 +13,26 @@ import sys import os import importlib +import pytest # Add the hyperpod-pytorch-job-template to the path for testing sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'hyperpod-pytorch-job-template')) +# Mock Kubernetes connectivity for all tests in this module +@pytest.fixture(autouse=True) +def mock_kubernetes_connectivity(): + """Mock Kubernetes connectivity checks for all tests in this module.""" + with patch('sagemaker.hyperpod.common.cli_decorators._check_kubernetes_connectivity') as mock_connectivity, \ + patch('sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation') as mock_is_k8s_op, \ + patch('sagemaker.hyperpod.common.cli_decorators._check_training_operator_exists') as mock_training_op: + # Always return successful connectivity + mock_connectivity.return_value = (True, "") + # Let the operation detection work normally + mock_is_k8s_op.side_effect = lambda func, **kwargs: True + # Always return that training operator exists + mock_training_op.return_value = True + yield + try: from hyperpod_pytorch_job_template.v1_1.model import PyTorchJobConfig, VolumeConfig from pydantic import ValidationError diff --git a/test/unit_tests/error_handling/test_cli_decorators.py b/test/unit_tests/error_handling/test_cli_decorators.py index bdb57c77..3ec193fb 100644 --- a/test/unit_tests/error_handling/test_cli_decorators.py +++ b/test/unit_tests/error_handling/test_cli_decorators.py @@ -21,6 +21,18 @@ _is_get_logs_operation ) +# Mock Kubernetes connectivity for all tests in this module +@pytest.fixture(autouse=True) +def mock_kubernetes_connectivity(): + """Mock Kubernetes connectivity checks for all tests in this module.""" + with patch('sagemaker.hyperpod.common.cli_decorators._check_kubernetes_connectivity') as mock_connectivity, \ + patch('sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation') as mock_is_k8s_op: + # Always return successful connectivity + mock_connectivity.return_value = (True, "") + # Let the operation detection work normally + mock_is_k8s_op.side_effect = lambda func, **kwargs: True + yield + class TestHandleCliExceptions: """Test template-agnostic handle_cli_exceptions decorator.""" diff --git a/test/unit_tests/test_kubernetes_auth_error_handling.py b/test/unit_tests/test_kubernetes_auth_error_handling.py new file mode 100644 index 00000000..f6dee68c --- /dev/null +++ b/test/unit_tests/test_kubernetes_auth_error_handling.py @@ -0,0 +1,495 @@ +""" +Unit tests for Kubernetes authentication error handling in cli_decorators.py +Tests all authentication scenarios and error message generation. +""" + +import pytest +import sys +import os +from unittest.mock import patch, MagicMock, Mock +from kubernetes.client.exceptions import ApiException +from kubernetes.config.config_exception import ConfigException + +# Add the src directory to the path so we can import the modules +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src')) + +from sagemaker.hyperpod.common.cli_decorators import ( + _check_kubernetes_connectivity, + _generate_kubernetes_auth_error_message, + _is_kubernetes_operation, + _check_aws_credentials, + _get_current_kubernetes_context, + handle_cli_exceptions +) + + +class TestKubernetesConnectivityCheck: + """Test the _check_kubernetes_connectivity function""" + + @patch('sagemaker.hyperpod.common.cli_decorators.config.load_kube_config') + @patch('kubernetes.client.VersionApi') + def test_successful_connection(self, mock_version_api, mock_load_config): + """Test successful Kubernetes connection""" + mock_version_instance = Mock() + mock_version_api.return_value = mock_version_instance + mock_version_instance.get_code.return_value = {"major": "1", "minor": "28"} + + is_connected, error_type = _check_kubernetes_connectivity() + + assert is_connected is True + assert error_type == "" + mock_load_config.assert_called_once() + mock_version_instance.get_code.assert_called_once() + + @patch('sagemaker.hyperpod.common.cli_decorators.config.load_kube_config') + def test_no_config_found(self, mock_load_config): + """Test ConfigException with no configuration found""" + mock_load_config.side_effect = ConfigException("No configuration found") + + is_connected, error_type = _check_kubernetes_connectivity() + + assert is_connected is False + assert error_type == "no_config" + + @patch('sagemaker.hyperpod.common.cli_decorators.config.load_kube_config') + def test_invalid_config(self, mock_load_config): + """Test ConfigException with invalid kube-config""" + mock_load_config.side_effect = ConfigException("Invalid kube-config file") + + is_connected, error_type = _check_kubernetes_connectivity() + + assert is_connected is False + assert error_type == "invalid_config" + + @patch('sagemaker.hyperpod.common.cli_decorators.config.load_kube_config') + @patch('kubernetes.client.VersionApi') + def test_401_unauthorized(self, mock_version_api, mock_load_config): + """Test 401 Unauthorized ApiException""" + mock_version_instance = Mock() + mock_version_api.return_value = mock_version_instance + mock_version_instance.get_code.side_effect = ApiException(status=401, reason="Unauthorized") + + is_connected, error_type = _check_kubernetes_connectivity() + + assert is_connected is False + assert error_type == "unauthorized" + + @patch('sagemaker.hyperpod.common.cli_decorators.config.load_kube_config') + @patch('kubernetes.client.VersionApi') + def test_403_forbidden(self, mock_version_api, mock_load_config): + """Test 403 Forbidden ApiException""" + mock_version_instance = Mock() + mock_version_api.return_value = mock_version_instance + mock_version_instance.get_code.side_effect = ApiException(status=403, reason="Forbidden") + + is_connected, error_type = _check_kubernetes_connectivity() + + assert is_connected is False + assert error_type == "forbidden" + + @patch('sagemaker.hyperpod.common.cli_decorators.config.load_kube_config') + @patch('kubernetes.client.VersionApi') + def test_connection_error(self, mock_version_api, mock_load_config): + """Test connection timeout error""" + mock_version_instance = Mock() + mock_version_api.return_value = mock_version_instance + mock_version_instance.get_code.side_effect = Exception("connection timeout") + + is_connected, error_type = _check_kubernetes_connectivity() + + assert is_connected is False + assert error_type == "connection_error" + + @patch('sagemaker.hyperpod.common.cli_decorators.config.load_kube_config') + @patch('kubernetes.client.VersionApi') + def test_unauthorized_string_error(self, mock_version_api, mock_load_config): + """Test unauthorized error in string format""" + mock_version_instance = Mock() + mock_version_api.return_value = mock_version_instance + mock_version_instance.get_code.side_effect = Exception("401 unauthorized access") + + is_connected, error_type = _check_kubernetes_connectivity() + + assert is_connected is False + assert error_type == "unauthorized" + + +class TestKubernetesAuthErrorMessages: + """Test the _generate_kubernetes_auth_error_message function""" + + def test_no_config_error_message(self): + """Test error message for no configuration found""" + message = _generate_kubernetes_auth_error_message("no_config") + + assert "❌ Kubernetes configuration not found" in message + assert "hyp set-cluster-context" in message + assert "aws eks update-kubeconfig" not in message + assert "💡 This will set up the necessary Kubernetes configuration" in message + + def test_invalid_config_error_message(self): + """Test error message for invalid configuration""" + message = _generate_kubernetes_auth_error_message("invalid_config") + + assert "❌ Invalid Kubernetes configuration" in message + assert "hyp set-cluster-context" in message + assert "aws eks update-kubeconfig" not in message + assert "💡 This will refresh your cluster configuration" in message + + @patch('sagemaker.hyperpod.common.cli_decorators._get_current_kubernetes_context') + @patch('sagemaker.hyperpod.common.cli_decorators._check_aws_credentials') + def test_unauthorized_with_valid_aws_creds(self, mock_aws_creds, mock_context): + """Test unauthorized error message with valid AWS credentials""" + mock_context.return_value = "my-cluster" + mock_aws_creds.return_value = True + + message = _generate_kubernetes_auth_error_message("unauthorized") + + assert "❌ Kubernetes authentication failed (401 Unauthorized)" in message + assert "Current context: my-cluster" in message + assert "hyp set-cluster-context" in message + assert "Try your HyperPod command again" in message + assert "aws eks update-kubeconfig" not in message + assert "💡 This will refresh your authentication" in message + + @patch('sagemaker.hyperpod.common.cli_decorators._get_current_kubernetes_context') + @patch('sagemaker.hyperpod.common.cli_decorators._check_aws_credentials') + def test_unauthorized_with_invalid_aws_creds(self, mock_aws_creds, mock_context): + """Test unauthorized error message with invalid AWS credentials""" + mock_context.return_value = "my-cluster" + mock_aws_creds.return_value = False + + message = _generate_kubernetes_auth_error_message("unauthorized") + + assert "❌ Kubernetes authentication failed (401 Unauthorized)" in message + assert "🔍 AWS credentials issue detected" in message + assert "aws sts get-caller-identity" in message + assert "hyp set-cluster-context" in message + assert "aws eks update-kubeconfig" not in message + assert "💡 Make sure your AWS credentials have the necessary EKS permissions" in message + + def test_forbidden_error_message(self): + """Test forbidden error message""" + message = _generate_kubernetes_auth_error_message("forbidden") + + assert "❌ Kubernetes access denied (403 Forbidden)" in message + assert "RBAC permissions" in message + assert "kubectl config current-context" not in message + assert "kubectl auth can-i get pods" not in message + assert "Verify you're using the correct cluster context" in message + assert "Contact your cluster administrator for access" in message + + def test_connection_error_message(self): + """Test connection error message""" + message = _generate_kubernetes_auth_error_message("connection_error") + + assert "❌ Cannot connect to Kubernetes cluster" in message + assert "Network connection to the cluster failed" in message + assert "hyp set-cluster-context" in message + assert "aws eks update-kubeconfig" not in message + + def test_generic_error_message(self): + """Test generic error message""" + message = _generate_kubernetes_auth_error_message("some_unknown_error") + + assert "❌ Kubernetes connection failed" in message + assert "Error: some_unknown_error" in message + assert "kubectl config view" not in message + assert "kubectl get nodes" not in message + assert "hyp set-cluster-context" in message + + +class TestKubernetesOperationDetection: + """Test the _is_kubernetes_operation function""" + + def test_logs_operation_detection(self): + """Test detection of logs operations""" + mock_func = Mock() + mock_func.__name__ = "js_get_operator_logs" + + result = _is_kubernetes_operation(mock_func) + + assert result is True + + def test_create_operation_detection(self): + """Test detection of create operations""" + mock_func = Mock() + mock_func.__name__ = "js_create_endpoint" + + result = _is_kubernetes_operation(mock_func) + + assert result is True + + def test_describe_operation_detection(self): + """Test detection of describe operations""" + mock_func = Mock() + mock_func.__name__ = "pytorch_describe_job" + + result = _is_kubernetes_operation(mock_func) + + assert result is True + + def test_non_kubernetes_operation(self): + """Test non-Kubernetes operation detection""" + mock_func = Mock() + mock_func.__name__ = "some_other_function" + + result = _is_kubernetes_operation(mock_func) + + assert result is False + + @patch('sagemaker.hyperpod.common.cli_decorators.click.get_current_context') + def test_click_command_detection(self, mock_get_context): + """Test detection via Click command context""" + mock_func = Mock() + mock_func.__name__ = "some_function" + + mock_context = Mock() + mock_context.info_name = "hyp-get-logs" + mock_get_context.return_value = mock_context + + result = _is_kubernetes_operation(mock_func) + + assert result is True + + def test_wrapped_function_detection(self): + """Test detection of wrapped functions""" + mock_func = Mock() + mock_func.__name__ = "wrapper" + + mock_wrapped = Mock() + mock_wrapped.__name__ = "get_operator_logs" + mock_func.__wrapped__ = mock_wrapped + + result = _is_kubernetes_operation(mock_func) + + assert result is True + + +class TestAWSCredentialsCheck: + """Test the _check_aws_credentials function""" + + @patch('boto3.client') + def test_valid_aws_credentials(self, mock_boto_client): + """Test valid AWS credentials""" + mock_sts = Mock() + mock_boto_client.return_value = mock_sts + mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} + + result = _check_aws_credentials() + + assert result is True + mock_boto_client.assert_called_once_with('sts') + mock_sts.get_caller_identity.assert_called_once() + + @patch('boto3.client') + def test_no_credentials_error(self, mock_boto_client): + """Test NoCredentialsError""" + from botocore.exceptions import NoCredentialsError + + mock_sts = Mock() + mock_boto_client.return_value = mock_sts + mock_sts.get_caller_identity.side_effect = NoCredentialsError() + + result = _check_aws_credentials() + + assert result is False + + @patch('boto3.client') + def test_partial_credentials_error(self, mock_boto_client): + """Test PartialCredentialsError""" + from botocore.exceptions import PartialCredentialsError + + mock_sts = Mock() + mock_boto_client.return_value = mock_sts + mock_sts.get_caller_identity.side_effect = PartialCredentialsError(provider="aws", cred_var="AWS_SECRET_ACCESS_KEY") + + result = _check_aws_credentials() + + assert result is False + + +class TestKubernetesContextRetrieval: + """Test the _get_current_kubernetes_context function""" + + @patch('sagemaker.hyperpod.common.cli_decorators.config.list_kube_config_contexts') + def test_get_current_context_success(self, mock_list_contexts): + """Test successful context retrieval""" + mock_contexts = [{"name": "context1"}, {"name": "context2"}] + mock_active = {"name": "my-cluster"} + mock_list_contexts.return_value = (mock_contexts, mock_active) + + result = _get_current_kubernetes_context() + + assert result == "my-cluster" + + @patch('sagemaker.hyperpod.common.cli_decorators.config.list_kube_config_contexts') + def test_get_current_context_no_active(self, mock_list_contexts): + """Test context retrieval with no active context""" + mock_contexts = [{"name": "context1"}, {"name": "context2"}] + mock_list_contexts.return_value = (mock_contexts, None) + + result = _get_current_kubernetes_context() + + assert result == "none" + + @patch('sagemaker.hyperpod.common.cli_decorators.config.list_kube_config_contexts') + def test_get_current_context_error(self, mock_list_contexts): + """Test context retrieval with error""" + mock_list_contexts.side_effect = Exception("Config error") + + result = _get_current_kubernetes_context() + + assert result == "unknown" + + +class TestDecoratorIntegration: + """Test the handle_cli_exceptions decorator integration""" + + @patch('sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation') + @patch('sagemaker.hyperpod.common.cli_decorators._check_kubernetes_connectivity') + @patch('sagemaker.hyperpod.common.cli_decorators.click.echo') + def test_decorator_proactive_auth_check_unauthorized(self, mock_echo, mock_connectivity, mock_is_k8s_op): + """Test decorator proactive authentication check for unauthorized error""" + mock_is_k8s_op.return_value = True + mock_connectivity.return_value = (False, "unauthorized") + + @handle_cli_exceptions() + def mock_function(): + return "success" + + mock_function.__name__ = "get_operator_logs" + + # Should exit with sys.exit(1) due to auth failure + with pytest.raises(SystemExit) as exc_info: + mock_function() + + assert exc_info.value.code == 1 + mock_echo.assert_called() + # Verify the error message contains expected content + error_message = mock_echo.call_args[0][0] + assert "❌ Kubernetes authentication failed (401 Unauthorized)" in error_message + assert "hyp set-cluster-context" in error_message + + @patch('sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation') + @patch('sagemaker.hyperpod.common.cli_decorators._check_kubernetes_connectivity') + def test_decorator_successful_auth_check(self, mock_connectivity, mock_is_k8s_op): + """Test decorator with successful authentication check""" + mock_is_k8s_op.return_value = True + mock_connectivity.return_value = (True, "") + + @handle_cli_exceptions() + def mock_function(): + return "success" + + mock_function.__name__ = "get_operator_logs" + + result = mock_function() + + assert result == "success" + + @patch('sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation') + @patch('sagemaker.hyperpod.common.cli_decorators._check_kubernetes_connectivity') + @patch('sagemaker.hyperpod.common.cli_decorators.click.echo') + def test_decorator_reactive_401_handling(self, mock_echo, mock_connectivity, mock_is_k8s_op): + """Test decorator reactive handling of 401 ApiException""" + mock_is_k8s_op.return_value = True + mock_connectivity.return_value = (True, "") + + @handle_cli_exceptions() + def mock_function(): + raise ApiException(status=401, reason="Unauthorized") + + mock_function.__name__ = "get_operator_logs" + + with pytest.raises(SystemExit) as exc_info: + mock_function() + + assert exc_info.value.code == 1 + mock_echo.assert_called() + error_message = mock_echo.call_args[0][0] + assert "❌ Kubernetes authentication failed (401 Unauthorized)" in error_message + + @patch('sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation') + @patch('sagemaker.hyperpod.common.cli_decorators._check_kubernetes_connectivity') + @patch('sagemaker.hyperpod.common.cli_decorators.click.echo') + def test_decorator_reactive_403_handling(self, mock_echo, mock_connectivity, mock_is_k8s_op): + """Test decorator reactive handling of 403 ApiException""" + mock_is_k8s_op.return_value = True + mock_connectivity.return_value = (True, "") + + @handle_cli_exceptions() + def mock_function(): + raise ApiException(status=403, reason="Forbidden") + + mock_function.__name__ = "get_operator_logs" + + with pytest.raises(SystemExit) as exc_info: + mock_function() + + assert exc_info.value.code == 1 + mock_echo.assert_called() + error_message = mock_echo.call_args[0][0] + assert "❌ Kubernetes access denied (403 Forbidden)" in error_message + + @patch('sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation') + @patch('sagemaker.hyperpod.common.cli_decorators._check_kubernetes_connectivity') + @patch('sagemaker.hyperpod.common.cli_decorators.click.echo') + def test_decorator_config_exception_handling(self, mock_echo, mock_connectivity, mock_is_k8s_op): + """Test decorator handling of ConfigException""" + mock_is_k8s_op.return_value = True + mock_connectivity.return_value = (True, "") + + @handle_cli_exceptions() + def mock_function(): + raise ConfigException("No configuration found") + + mock_function.__name__ = "get_operator_logs" + + with pytest.raises(SystemExit) as exc_info: + mock_function() + + assert exc_info.value.code == 1 + mock_echo.assert_called() + error_message = mock_echo.call_args[0][0] + assert "❌ Kubernetes configuration not found" in error_message + + @patch('sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation') + @patch('sagemaker.hyperpod.common.cli_decorators._check_kubernetes_connectivity') + @patch('sagemaker.hyperpod.common.cli_decorators.click.echo') + def test_decorator_string_401_handling(self, mock_echo, mock_connectivity, mock_is_k8s_op): + """Test decorator handling of string-based 401 errors""" + mock_is_k8s_op.return_value = True + mock_connectivity.return_value = (True, "") + + @handle_cli_exceptions() + def mock_function(): + raise Exception("401 Unauthorized: the server has asked for the client to provide credentials") + + mock_function.__name__ = "get_operator_logs" + + with pytest.raises(SystemExit) as exc_info: + mock_function() + + assert exc_info.value.code == 1 + mock_echo.assert_called() + error_message = mock_echo.call_args[0][0] + assert "❌ Kubernetes authentication failed (401 Unauthorized)" in error_message + + @patch('sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation') + def test_decorator_non_kubernetes_operation(self, mock_is_k8s_op): + """Test decorator with non-Kubernetes operation""" + mock_is_k8s_op.return_value = False + + @handle_cli_exceptions() + def mock_function(): + return "success" + + mock_function.__name__ = "some_other_function" + + result = mock_function() + + assert result == "success" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])