Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/sagemaker/hyperpod/cli/commands/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,11 @@ def timeout_handler(signum, frame):
sm_client = get_sagemaker_client(session, botocore_config)
hp_cluster_details = sm_client.describe_cluster(ClusterName=cluster_name)
logger.debug("Fetched hyperpod cluster details")

# Check if cluster is EKS-orchestrated
if "Orchestrator" not in hp_cluster_details or "Eks" not in hp_cluster_details.get("Orchestrator", {}):
raise ValueError(f"Cluster '{cluster_name}' is not EKS-orchestrated. HyperPod CLI only supports EKS-orchestrated clusters.")

store_current_hyperpod_context(hp_cluster_details)
eks_cluster_arn = hp_cluster_details["Orchestrator"]["Eks"]["ClusterArn"]
logger.debug(
Expand Down
7 changes: 6 additions & 1 deletion src/sagemaker/hyperpod/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def setup_logging(logger, debug=False):

def is_eks_orchestrator(sagemaker_client, cluster_name: str) -> bool:
    """Return True when the named cluster is orchestrated by EKS.

    Calls ``describe_cluster`` on the given client and inspects the
    ``Orchestrator`` entry of the response.

    Args:
        sagemaker_client: Client object exposing
            ``describe_cluster(ClusterName=...)`` that returns a dict.
        cluster_name: Name passed through as ``ClusterName``.

    Returns:
        True if the response contains a non-None ``Orchestrator["Eks"]``
        entry, False otherwise (including when ``Orchestrator`` is absent).
    """
    response = sagemaker_client.describe_cluster(ClusterName=cluster_name)
    # A response may carry no "Orchestrator" key at all (or a None value);
    # treat both as "not EKS" instead of raising KeyError/AttributeError.
    orchestrator = response.get("Orchestrator") or {}
    return orchestrator.get("Eks") is not None


def update_kube_config(
Expand Down Expand Up @@ -250,6 +250,9 @@ def set_cluster_context(

client = boto3.client("sagemaker", region_name=region)

if not is_eks_orchestrator(client, cluster_name):
raise ValueError(f"Cluster '{cluster_name}' is not EKS-orchestrated. HyperPod CLI only supports EKS-orchestrated clusters.")

response = client.describe_cluster(ClusterName=cluster_name)
eks_cluster_arn = response["Orchestrator"]["Eks"]["ClusterArn"]
eks_name = get_eks_name_from_arn(eks_cluster_arn)
Expand Down Expand Up @@ -300,6 +303,8 @@ def get_current_cluster():
client = boto3.client("sagemaker", region_name=region)

for cluster_name in hyperpod_clusters:
if not is_eks_orchestrator(client, cluster_name):
continue
response = client.describe_cluster(ClusterName=cluster_name)
if response["Orchestrator"]["Eks"]["ClusterArn"] == current_context:
return cluster_name
Expand Down
7 changes: 6 additions & 1 deletion test/unit_tests/common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,12 @@ def test_set_cluster_context(self, mock_set_context_func, mock_update_config, mo

set_cluster_context("my-cluster", "us-west-2", "test-namespace")

mock_client.describe_cluster.assert_called_once_with(ClusterName="my-cluster")
# Expect 2 calls: one for is_eks_orchestrator validation, one for getting cluster details
self.assertEqual(mock_client.describe_cluster.call_count, 2)
mock_client.describe_cluster.assert_has_calls([
call(ClusterName="my-cluster"),
call(ClusterName="my-cluster")
])
mock_get_name.assert_called_once()
mock_update_config.assert_called_once()
mock_set_context_func.assert_called_once()
Expand Down
Loading