Reject or skip HyperPod clusters that are not EKS-orchestrated.

Summary of what the hunks below do:
  * cli/commands/cluster.py: after describe_cluster, raise ValueError when the
    response has no Orchestrator.Eks entry, before the context is stored and
    the EKS cluster ARN is read.
  * common/utils.py: is_eks_orchestrator now uses .get() lookups, so a response
    without an "Orchestrator" key yields False instead of raising KeyError.
    set_cluster_context raises ValueError for a non-EKS cluster, and
    get_current_cluster skips non-EKS clusters while scanning.
  * Each is_eks_orchestrator call issues its own describe_cluster request, so
    set_cluster_context now calls describe_cluster twice; the unit test is
    updated to expect exactly two calls.

NOTE(review): the leading whitespace of context/added lines and the position
of blank context lines were reconstructed from a whitespace-collapsed copy of
this patch; confirm against the target files before applying.
NOTE(review): the test hunk uses `call`; confirm that test_utils.py already
imports it (presumably from unittest.mock).

diff --git a/src/sagemaker/hyperpod/cli/commands/cluster.py b/src/sagemaker/hyperpod/cli/commands/cluster.py
index cb19f24c..c01142eb 100644
--- a/src/sagemaker/hyperpod/cli/commands/cluster.py
+++ b/src/sagemaker/hyperpod/cli/commands/cluster.py
@@ -584,6 +584,11 @@ def timeout_handler(signum, frame):
         sm_client = get_sagemaker_client(session, botocore_config)
         hp_cluster_details = sm_client.describe_cluster(ClusterName=cluster_name)
         logger.debug("Fetched hyperpod cluster details")
+
+        # Check if cluster is EKS-orchestrated
+        if "Orchestrator" not in hp_cluster_details or "Eks" not in hp_cluster_details.get("Orchestrator", {}):
+            raise ValueError(f"Cluster '{cluster_name}' is not EKS-orchestrated. HyperPod CLI only supports EKS-orchestrated clusters.")
+
         store_current_hyperpod_context(hp_cluster_details)
         eks_cluster_arn = hp_cluster_details["Orchestrator"]["Eks"]["ClusterArn"]
         logger.debug(
diff --git a/src/sagemaker/hyperpod/common/utils.py b/src/sagemaker/hyperpod/common/utils.py
index 0a25b974..15e73ba8 100644
--- a/src/sagemaker/hyperpod/common/utils.py
+++ b/src/sagemaker/hyperpod/common/utils.py
@@ -159,7 +159,7 @@ def setup_logging(logger, debug=False):
 
 def is_eks_orchestrator(sagemaker_client, cluster_name: str):
     response = sagemaker_client.describe_cluster(ClusterName=cluster_name)
-    return "Eks" in response["Orchestrator"]
+    return response.get("Orchestrator", {}).get("Eks") is not None
 
 
 def update_kube_config(
@@ -250,6 +250,9 @@ def set_cluster_context(
 
     client = boto3.client("sagemaker", region_name=region)
 
+    if not is_eks_orchestrator(client, cluster_name):
+        raise ValueError(f"Cluster '{cluster_name}' is not EKS-orchestrated. HyperPod CLI only supports EKS-orchestrated clusters.")
+
     response = client.describe_cluster(ClusterName=cluster_name)
     eks_cluster_arn = response["Orchestrator"]["Eks"]["ClusterArn"]
     eks_name = get_eks_name_from_arn(eks_cluster_arn)
@@ -300,6 +303,8 @@ def get_current_cluster():
     client = boto3.client("sagemaker", region_name=region)
 
     for cluster_name in hyperpod_clusters:
+        if not is_eks_orchestrator(client, cluster_name):
+            continue
         response = client.describe_cluster(ClusterName=cluster_name)
         if response["Orchestrator"]["Eks"]["ClusterArn"] == current_context:
             return cluster_name
diff --git a/test/unit_tests/common/test_utils.py b/test/unit_tests/common/test_utils.py
index 81fc7930..7ba025b3 100644
--- a/test/unit_tests/common/test_utils.py
+++ b/test/unit_tests/common/test_utils.py
@@ -389,7 +389,12 @@ def test_set_cluster_context(self, mock_set_context_func, mock_update_config, mo
 
         set_cluster_context("my-cluster", "us-west-2", "test-namespace")
 
-        mock_client.describe_cluster.assert_called_once_with(ClusterName="my-cluster")
+        # Expect 2 calls: one for is_eks_orchestrator validation, one for getting cluster details
+        self.assertEqual(mock_client.describe_cluster.call_count, 2)
+        mock_client.describe_cluster.assert_has_calls([
+            call(ClusterName="my-cluster"),
+            call(ClusterName="my-cluster")
+        ])
         mock_get_name.assert_called_once()
         mock_update_config.assert_called_once()
         mock_set_context_func.assert_called_once()