From 3c7283377a4431bc53d2dc2319b61a1f36d5818b Mon Sep 17 00:00:00 2001 From: Pratibha Shrivastav Date: Mon, 25 Aug 2025 11:58:18 +0530 Subject: [PATCH 1/2] port 22 compute connect fix --- .../azext_mlv2/manual/custom/_ssh_command.py | 6 ++++-- .../azext_mlv2/manual/custom/_ssh_connector.py | 8 +++++++- .../azext_mlv2/manual/custom/compute.py | 4 ++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/machinelearningservices/azext_mlv2/manual/custom/_ssh_command.py b/src/machinelearningservices/azext_mlv2/manual/custom/_ssh_command.py index 58498d22c84..177491ad548 100644 --- a/src/machinelearningservices/azext_mlv2/manual/custom/_ssh_command.py +++ b/src/machinelearningservices/azext_mlv2/manual/custom/_ssh_command.py @@ -24,7 +24,8 @@ def get_ssh_command( services_dict: Dict[str, ServiceInstance], node_index: int, private_key_file_path: str, - ssh_args: Optional[Sequence[str]] = None + ssh_args: Optional[Sequence[str]] = None, + connector_args: Optional[Sequence[str]] = None ) -> Tuple[bool, str]: proxyEndpoint = _get_proxy_endpoint(services_dict, node_index).replace("", str(node_index)) connect_ssh_path = pathlib.Path(__file__).parent / "_ssh_connector.py" @@ -45,9 +46,10 @@ def get_ssh_command( identity_param = " -i {}".format(private_key_file_path) if private_key_file_path else "" # TODO: Find how to enable debug mode ssh_args_str = " ".join(ssh_args) if ssh_args else "" + connector_args_str = " ".join(connector_args) if connector_args else "" return ( connect_ssh_path_has_space, - f'{ssh_path} -v -o ProxyCommand="{sys.executable} {connect_ssh_path} {proxyEndpoint}" ' + f'{ssh_path} -v -o ProxyCommand="{sys.executable} {connect_ssh_path} {proxyEndpoint} {connector_args_str}" ' f"azureuser@{proxyEndpoint}{identity_param}{ssh_args_str}", ) diff --git a/src/machinelearningservices/azext_mlv2/manual/custom/_ssh_connector.py b/src/machinelearningservices/azext_mlv2/manual/custom/_ssh_connector.py index 7d2f39f8961..b49fab5e917 100644 --- a/src/machinelearningservices/azext_mlv2/manual/custom/_ssh_connector.py +++ b/src/machinelearningservices/azext_mlv2/manual/custom/_ssh_connector.py @@ -61,12 +61,18 @@ async def _connect_ssh(self): ) raise Exception(msg) # pylint: disable=broad-exception-raised proxy_endpoint = sys.argv[1] + + is_compute = len(sys.argv) > 2 and sys.argv[2] == "--is-compute" + + uri = f"{proxy_endpoint}/nbip/v1.0/ws-tcp" + if is_compute: + uri += "/port/22" mgtScope = ["https://management.core.windows.net/.default"] aml_token = run_az_cli(["account", "get-access-token", "--scope", mgtScope[0]])["accessToken"] async with websockets.client.connect( - uri=f"{proxy_endpoint}/nbip/v1.0/ws-tcp", + uri=uri, extra_headers={"Authorization": f"Bearer {aml_token}"}, ) as websocket: diff --git a/src/machinelearningservices/azext_mlv2/manual/custom/compute.py b/src/machinelearningservices/azext_mlv2/manual/custom/compute.py index a5944007288..4a16d9825cc 100644 --- a/src/machinelearningservices/azext_mlv2/manual/custom/compute.py +++ b/src/machinelearningservices/azext_mlv2/manual/custom/compute.py @@ -274,12 +274,12 @@ def ml_compute_connect_ssh(cmd, resource_group_name, workspace_name, name, priva # create proxy endpoint for CI based on endpoint for jupyter # TODO: Improve with a call to get proxyendpoint from CI, requires API jupyter = [f["endpoint_uri"] for f in compute.services if f["display_name"] == "Jupyter"][0] - proxyEndpoint = jupyter.replace(name, f"{name}-22").replace("https://", "wss://").replace("/tree/", "") + proxyEndpoint = jupyter.replace("https://", "wss://").replace("/tree/", "") services_dict = { "ssh": ServiceInstance(type="SSH", status="Running", properties={"ProxyEndpoint": proxyEndpoint}) } - path_has_space, ssh_command = get_ssh_command(services_dict, 0, private_key_file_path) + path_has_space, ssh_command = get_ssh_command(services_dict, 0, private_key_file_path, connector_args=["--is-compute"]) print(f"ssh_command: {ssh_command}") if path_has_space: module_logger.error(ssh_connector_file_path_space_message()) From 75fadcbe9b763d91e1036b6e96333e41cc4c1840 Mon Sep 17 00:00:00 2001 From: Pratibha Shrivastav Date: Mon, 25 Aug 2025 12:10:01 +0530 Subject: [PATCH 2/2] add changelog --- src/machinelearningservices/CHANGELOG.rst | 2 ++ .../azext_mlv2/manual/custom/_ssh_connector.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/machinelearningservices/CHANGELOG.rst b/src/machinelearningservices/CHANGELOG.rst index 1fb194c0002..6d786917f80 100644 --- a/src/machinelearningservices/CHANGELOG.rst +++ b/src/machinelearningservices/CHANGELOG.rst @@ -1,6 +1,8 @@ ## Azure Machine Learning CLI (v2) (unreleased) - `az ml compute update` - Fix a bug compute update which caused Enable SSO property to reset. +- `az ml compute connect-ssh` + - Fix proxy endpoint path ## 2025-05-15 diff --git a/src/machinelearningservices/azext_mlv2/manual/custom/_ssh_connector.py b/src/machinelearningservices/azext_mlv2/manual/custom/_ssh_connector.py index b49fab5e917..20611054143 100644 --- a/src/machinelearningservices/azext_mlv2/manual/custom/_ssh_connector.py +++ b/src/machinelearningservices/azext_mlv2/manual/custom/_ssh_connector.py @@ -63,7 +63,6 @@ async def _connect_ssh(self): proxy_endpoint = sys.argv[1] is_compute = len(sys.argv) > 2 and sys.argv[2] == "--is-compute" - uri = f"{proxy_endpoint}/nbip/v1.0/ws-tcp" if is_compute: uri += "/port/22"