Skip to content
254 changes: 254 additions & 0 deletions src/sagemaker/hyperpod/common/cli_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import click
import functools
import logging
import subprocess
import os
from kubernetes.client.exceptions import ApiException
from kubernetes import config
from kubernetes.config.config_exception import ConfigException

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -761,6 +765,210 @@ def _check_resources_exist(raw_resource_type: str, namespace: str) -> bool:
logger.debug(f"Failed to check resource existence for {raw_resource_type}: {e}")
return None


def _check_kubernetes_connectivity() -> tuple[bool, str]:
    """
    Probe the Kubernetes cluster selected by the local kubeconfig.

    Loads the kubeconfig and requests the server version. Any failure is
    translated into a short error code instead of being raised.

    Returns:
        ``(True, "")`` when the probe succeeds, otherwise ``(False, code)``
        where ``code`` is one of:

        * ``"no_config"`` / ``"invalid_config"`` / ``"config_error: <detail>"``
          for kubeconfig loading problems,
        * ``"unauthorized"`` (401) / ``"forbidden"`` (403) /
          ``"api_error: <status> <reason>"`` for API-level failures,
        * ``"connection_error"`` or ``"unknown_error: <detail>"`` for anything else.
    """
    # Function-scope import so this block does not depend on a file-level import.
    import re

    try:
        # Try to load kubeconfig and make a simple API call
        config.load_kube_config()
        from kubernetes import client

        # Fetching the server version is a single cheap request that exercises
        # the whole path (config -> network -> API server).
        version_api = client.VersionApi()
        version_api.get_code()

        return True, ""
    except ConfigException as e:
        detail = str(e)
        if "No configuration found" in detail:
            return False, "no_config"
        if "Invalid kube-config" in detail:
            return False, "invalid_config"
        return False, f"config_error: {detail}"
    except ApiException as e:
        if e.status == 401:
            return False, "unauthorized"
        if e.status == 403:
            return False, "forbidden"
        return False, f"api_error: {e.status} {e.reason}"
    except Exception as e:
        error_str = str(e).lower()

        def _mentions_status(code: str) -> bool:
            # Match the status code only as a standalone token. A plain
            # substring test misfires on connection errors whose text merely
            # contains those digits, e.g. inside a cluster endpoint hostname.
            return re.search(rf"\b{code}\b", error_str) is not None

        if "unauthorized" in error_str or _mentions_status("401"):
            return False, "unauthorized"
        if "forbidden" in error_str or _mentions_status("403"):
            return False, "forbidden"
        if "connection" in error_str or "timeout" in error_str:
            return False, "connection_error"
        return False, f"unknown_error: {str(e)}"

def _check_aws_credentials() -> bool:
    """
    Report whether usable AWS credentials are available.

    Calls STS ``get_caller_identity`` as the probe.

    Returns:
        True when the STS call succeeds; False when credentials are missing or
        incomplete, when the AWS SDK cannot be imported, or when the call fails
        for any other reason (the reason is logged at debug level).
    """
    # Import separately from the probe: the ``except`` clause below names the
    # botocore exception classes, so they must be bound before that clause is
    # evaluated. With the imports inside the same ``try``, a failed import
    # surfaced as UnboundLocalError instead of a clean False.
    try:
        import boto3
        from botocore.exceptions import NoCredentialsError, PartialCredentialsError
    except ImportError as e:
        logger.debug(f"AWS credentials check failed: {e}")
        return False

    try:
        # Try to get caller identity
        sts = boto3.client('sts')
        sts.get_caller_identity()
        return True
    except (NoCredentialsError, PartialCredentialsError):
        return False
    except Exception as e:
        logger.debug(f"AWS credentials check failed: {e}")
        return False

def _get_current_kubernetes_context() -> str:
    """
    Look up the name of the active kubeconfig context.

    Returns:
        The active context's ``name`` entry ('unknown' if the entry is absent),
        'none' when no context is active, or 'unknown' when the lookup raises.
    """
    try:
        _all_contexts, active = config.list_kube_config_contexts()
        if not active:
            return 'none'
        return active.get('name', 'unknown')
    except Exception as e:
        # Best effort only: this feeds an error message, so never propagate.
        logger.debug(f"Failed to get current context: {e}")
        return 'unknown'

def _generate_kubernetes_auth_error_message(error_type: str) -> str:
"""
Generate helpful error message for Kubernetes authentication issues.
"""
if error_type == "no_config":
return (
"❌ Kubernetes configuration not found.\n"
"No kubeconfig file found. Please ensure you have:\n"
"1. A valid kubeconfig file at ~/.kube/config, or\n"
"2. Set the KUBECONFIG environment variable\n\n"
"To configure cluster access:\n"
" hyp set-cluster-context <cluster-name> --region <region>\n\n"
"💡 This will set up the necessary Kubernetes configuration for your HyperPod cluster."
)

elif error_type == "invalid_config":
return (
"❌ Invalid Kubernetes configuration.\n"
"Your kubeconfig file appears to be corrupted or invalid.\n\n"
"To fix this:\n"
"1. Check your kubeconfig file at ~/.kube/config\n"
"2. Reconfigure cluster access:\n"
" hyp set-cluster-context <cluster-name> --region <region>\n\n"
"💡 This will refresh your cluster configuration with the correct settings."
)

elif error_type == "unauthorized":
current_context = _get_current_kubernetes_context()
aws_creds_valid = _check_aws_credentials()

message = (
"❌ Kubernetes authentication failed (401 Unauthorized).\n"
f"Current context: {current_context}\n\n"
"This usually means your credentials have expired or are invalid.\n\n"
)

if not aws_creds_valid:
message += (
"🔍 AWS credentials issue detected:\n"
"Your AWS credentials appear to be missing or invalid.\n\n"
"To fix this:\n"
"1. Check your AWS credentials:\n"
" aws sts get-caller-identity\n"
"2. If expired, refresh your AWS credentials\n"
"3. Then reconfigure cluster access:\n"
" hyp set-cluster-context <cluster-name> --region <region>\n\n"
"💡 Make sure your AWS credentials have the necessary EKS permissions."
)
else:
message += (
"To fix this:\n"
"1. Reconfigure cluster access:\n"
" hyp set-cluster-context <cluster-name> --region <region>\n"
"2. Try your HyperPod command again\n\n"
"💡 This will refresh your authentication with the cluster."
)

return message

elif error_type == "forbidden":
return (
"❌ Kubernetes access denied (403 Forbidden).\n"
"You don't have permission to access this cluster or resource.\n\n"
"This could mean:\n"
"1. Your user/role lacks the necessary RBAC permissions\n"
"2. You're connected to the wrong cluster\n"
"3. The cluster's access policies have changed\n\n"
"To fix this:\n"
"1. Verify you're using the correct cluster context\n"
"2. Contact your cluster administrator for access\n"
"3. Ensure your AWS role has the necessary EKS permissions"
)

elif error_type == "connection_error":
return (
"❌ Cannot connect to Kubernetes cluster.\n"
"Network connection to the cluster failed.\n\n"
"This could mean:\n"
"1. The cluster is not accessible from your network\n"
"2. The cluster endpoint URL is incorrect\n"
"3. Network connectivity issues\n\n"
"To fix this:\n"
"1. Check your network connection\n"
"2. Verify the cluster is running and accessible\n"
"3. Reconfigure cluster access:\n"
" hyp set-cluster-context <cluster-name> --region <region>"
)

else:
return (
"❌ Kubernetes connection failed.\n"
f"Error: {error_type}\n\n"
"To troubleshoot:\n"
"1. Check your kubeconfig file at ~/.kube/config\n"
"2. Reconfigure cluster access:\n"
" hyp set-cluster-context <cluster-name> --region <region>\n"
"3. Try your HyperPod command again"
)

def _is_kubernetes_operation(func, **kwargs) -> bool:
"""
Detect if this operation requires Kubernetes connectivity.
"""
try:
# Check function name for Kubernetes-related patterns
func_name = func.__name__.lower()
k8s_patterns = ['logs', 'operator', 'pod', 'describe', 'list', 'delete', 'create']

if any(pattern in func_name for pattern in k8s_patterns):
return True

# Check if wrapped function has Kubernetes patterns
if hasattr(func, '__wrapped__'):
wrapped_name = getattr(func.__wrapped__, '__name__', '').lower()
if any(pattern in wrapped_name for pattern in k8s_patterns):
return True

# Check Click command info for Kubernetes patterns
try:
click_ctx = click.get_current_context(silent=True)
if click_ctx and hasattr(click_ctx, 'info_name'):
command_path = str(click_ctx.info_name).lower()
if any(pattern in command_path for pattern in k8s_patterns):
return True
except Exception:
pass

except Exception as e:
logger.debug(f"Failed to detect Kubernetes operation: {e}")

return False

def handle_cli_exceptions():
"""
Template-agnostic decorator with proactive namespace validation and enhanced error handling.
Expand Down Expand Up @@ -815,6 +1023,15 @@ def wrapper(*args, **kwargs):
sys.exit(1)
return

# Kubernetes connectivity check for operations that require it
if _is_kubernetes_operation(func, **kwargs):
is_connected, error_type = _check_kubernetes_connectivity()
if not is_connected:
auth_error_message = _generate_kubernetes_auth_error_message(error_type)
click.echo(auth_error_message)
sys.exit(1)
return

# Execute the command
try:
return func(*args, **kwargs)
Expand All @@ -828,6 +1045,43 @@ def wrapper(*args, **kwargs):
sys.exit(1)
return

# 2.1: Enhanced Kubernetes Authentication Error Handling
# Check for 401 Unauthorized errors and provide helpful guidance
if isinstance(e, ApiException) and e.status == 401:
auth_error_message = _generate_kubernetes_auth_error_message("unauthorized")
click.echo(auth_error_message)
sys.exit(1)
return
elif "401" in str(e) and ("unauthorized" in str(e).lower() or "Unauthorized" in str(e)):
auth_error_message = _generate_kubernetes_auth_error_message("unauthorized")
click.echo(auth_error_message)
sys.exit(1)
return

# 2.2: Enhanced Kubernetes Forbidden Error Handling
elif isinstance(e, ApiException) and e.status == 403:
auth_error_message = _generate_kubernetes_auth_error_message("forbidden")
click.echo(auth_error_message)
sys.exit(1)
return
elif "403" in str(e) and ("forbidden" in str(e).lower() or "Forbidden" in str(e)):
auth_error_message = _generate_kubernetes_auth_error_message("forbidden")
click.echo(auth_error_message)
sys.exit(1)
return

# 2.3: Enhanced Kubernetes Configuration Error Handling
elif isinstance(e, ConfigException):
if "No configuration found" in str(e):
auth_error_message = _generate_kubernetes_auth_error_message("no_config")
elif "Invalid kube-config" in str(e):
auth_error_message = _generate_kubernetes_auth_error_message("invalid_config")
else:
auth_error_message = _generate_kubernetes_auth_error_message(f"config_error: {str(e)}")
click.echo(auth_error_message)
sys.exit(1)
return

# 3: Enhanced 404 Resource Handling with Dynamic Target Detection
# Check if this is a 404 error that can benefit from enhanced handling
if isinstance(e, ApiException) and e.status == 404:
Expand Down
12 changes: 12 additions & 0 deletions test/unit_tests/cli/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@
import hyperpod_jumpstart_inference_template.registry as jreg
import hyperpod_custom_inference_template.registry as creg

# Stub out the decorator's Kubernetes pre-flight helpers for every test in this module
@pytest.fixture(autouse=True)
def mock_kubernetes_connectivity():
    """
    Autouse fixture that patches the Kubernetes pre-flight helpers in cli_decorators.

    While a test runs, ``_check_kubernetes_connectivity`` reports a successful
    connection and ``_is_kubernetes_operation`` answers True for every function.
    """
    with patch(
        'sagemaker.hyperpod.common.cli_decorators._check_kubernetes_connectivity',
        return_value=(True, ""),  # connectivity probe always succeeds
    ), \
         patch(
        'sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation',
        side_effect=lambda func, **kwargs: True,  # every command counts as a Kubernetes operation
    ):
        yield

# Import the non-create commands that don't need special handling
from sagemaker.hyperpod.cli.commands.inference import (
js_create, custom_create, custom_invoke,
Expand Down
16 changes: 16 additions & 0 deletions test/unit_tests/cli/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,26 @@
import sys
import os
import importlib
import pytest

# Add the hyperpod-pytorch-job-template to the path for testing
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'hyperpod-pytorch-job-template'))

# Stub out the decorator's Kubernetes pre-flight helpers for every test in this module
@pytest.fixture(autouse=True)
def mock_kubernetes_connectivity():
    """
    Autouse fixture that patches the Kubernetes pre-flight helpers in cli_decorators.

    While a test runs, ``_check_kubernetes_connectivity`` reports a successful
    connection, ``_is_kubernetes_operation`` answers True for every function,
    and ``_check_training_operator_exists`` returns True.
    """
    with patch(
        'sagemaker.hyperpod.common.cli_decorators._check_kubernetes_connectivity',
        return_value=(True, ""),  # connectivity probe always succeeds
    ), \
         patch(
        'sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation',
        side_effect=lambda func, **kwargs: True,  # every command counts as a Kubernetes operation
    ), \
         patch(
        'sagemaker.hyperpod.common.cli_decorators._check_training_operator_exists',
        return_value=True,  # training operator is always reported as present
    ):
        yield

try:
from hyperpod_pytorch_job_template.v1_1.model import PyTorchJobConfig, VolumeConfig
from pydantic import ValidationError
Expand Down
12 changes: 12 additions & 0 deletions test/unit_tests/error_handling/test_cli_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@
_is_get_logs_operation
)

# Stub out the decorator's Kubernetes pre-flight helpers for every test in this module
@pytest.fixture(autouse=True)
def mock_kubernetes_connectivity():
    """
    Autouse fixture that patches the Kubernetes pre-flight helpers in cli_decorators.

    While a test runs, ``_check_kubernetes_connectivity`` reports a successful
    connection and ``_is_kubernetes_operation`` answers True for every function.
    """
    with patch(
        'sagemaker.hyperpod.common.cli_decorators._check_kubernetes_connectivity',
        return_value=(True, ""),  # connectivity probe always succeeds
    ), \
         patch(
        'sagemaker.hyperpod.common.cli_decorators._is_kubernetes_operation',
        side_effect=lambda func, **kwargs: True,  # every command counts as a Kubernetes operation
    ):
        yield


class TestHandleCliExceptions:
"""Test template-agnostic handle_cli_exceptions decorator."""
Expand Down
Loading
Loading