Skip to content

Commit

Permalink
Fix OOB for IPython and refactor. Closes kubeflow#10075.
Browse files Browse the repository at this point in the history
  • Loading branch information
gkcalat committed Oct 13, 2023
1 parent 5c44143 commit 37124b1
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 39 deletions.
15 changes: 14 additions & 1 deletion sdk/python/kfp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ def get_auth_code(client_id: str) -> Tuple[str, str]:
'scope=openid%20email&access_type=offline&'
f'redirect_uri={redirect_uri}')
authorization_response = None
if ('SSH_CONNECTION' in os.environ) or ('SSH_CLIENT' in os.environ):
if ('SSH_CONNECTION' in os.environ) or ('SSH_CLIENT'
in os.environ) or is_ipython():
try:
print((
'SSH connection detected. Please follow the instructions below.'
Expand Down Expand Up @@ -509,3 +510,15 @@ def fetch_auth_token_from_response(url: str) -> str:
if isinstance(access_code, list):
access_code = str(access_code.pop(0))
return access_code


def is_ipython() -> bool:
"""Returns whether we are running in notebook."""
try:
import IPython
ipy = IPython.get_ipython()
if ipy is None:
return False
except ImportError:
return False
return True
107 changes: 107 additions & 0 deletions sdk/python/kfp/client/auth_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2023 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from unittest.mock import MagicMock
from unittest.mock import patch

from absl.testing import parameterized
from kfp.client import auth


class TestAuth(parameterized.TestCase):

def test_is_ipython_return_false(self):
mock = MagicMock()
with patch.dict('sys.modules', IPython=mock):
mock.get_ipython.return_value = None
self.assertFalse(auth.is_ipython())

def test_is_ipython_return_true(self):
mock = MagicMock()
with patch.dict('sys.modules', IPython=mock):
mock.get_ipython.return_value = 'Something'
self.assertTrue(auth.is_ipython())

def test_is_ipython_should_raise_error(self):
mock = MagicMock()
with patch.dict('sys.modules', mock):
mock.side_effect = ImportError
self.assertFalse(auth.is_ipython())

@patch('builtins.input', lambda *args:
'https://oauth2.example.com/auth?code=4/P7q7W91a-oMsCeLvIaQm6bTrgtp7'
)
@patch('kfp.client.auth.is_ipython', lambda *args: True)
def test_get_auth_code_from_ipython(self):
os.environ.pop('SSH_CONNECTION', None)
os.environ.pop('SSH_CLIENT', None)
token, redirect_uri = auth.get_auth_code('sample-client-id')
self.assertEqual(token, '4/P7q7W91a-oMsCeLvIaQm6bTrgtp7')
self.assertEqual(redirect_uri, 'http://localhost:9901')

@patch('builtins.input', lambda *args:
'https://oauth2.example.com/auth?code=4/P7q7W91a-oMsCeLvIaQm6bTrgtp7'
)
@patch('kfp.client.auth.is_ipython', lambda *args: False)
def test_get_auth_code_from_remote_connection(self):
os.environ['SSH_CONNECTION'] = 'ENABLED'
os.environ.pop('SSH_CLIENT', None)
token, redirect_uri = auth.get_auth_code('sample-client-id')
self.assertEqual(token, '4/P7q7W91a-oMsCeLvIaQm6bTrgtp7')
self.assertEqual(redirect_uri, 'http://localhost:9901')

@patch('builtins.input', lambda *args:
'https://oauth2.example.com/auth?code=4/P7q7W91a-oMsCeLvIaQm6bTrgtp7'
)
@patch('kfp.client.auth.is_ipython', lambda *args: False)
def test_get_auth_code_from_remote_client(self):
os.environ.pop('SSH_CONNECTION', None)
os.environ['SSH_CLIENT'] = 'ENABLED'
token, redirect_uri = auth.get_auth_code('sample-client-id')
self.assertEqual(token, '4/P7q7W91a-oMsCeLvIaQm6bTrgtp7')
self.assertEqual(redirect_uri, 'http://localhost:9901')

@patch('builtins.input', lambda *args: 'https://oauth2.example.com/auth')
@patch('kfp.client.auth.is_ipython', lambda *args: False)
def test_get_auth_code_from_remote_client_missing_code(self):
os.environ.pop('SSH_CONNECTION', None)
os.environ['SSH_CLIENT'] = 'ENABLED'
self.assertRaises(KeyError, auth.get_auth_code, 'sample-client-id')

@patch('kfp.client.auth.get_auth_response_local', lambda *args:
'https://oauth2.example.com/auth?code=4/P7q7W91a-oMsCeLvIaQm6bTrgtp7'
)
@patch('kfp.client.auth.is_ipython', lambda *args: False)
def test_get_auth_code_from_local(self):
os.environ.pop('SSH_CONNECTION', None)
os.environ.pop('SSH_CLIENT', None)
token, redirect_uri = auth.get_auth_code('sample-client-id')
self.assertEqual(token, '4/P7q7W91a-oMsCeLvIaQm6bTrgtp7')
self.assertEqual(redirect_uri, 'http://localhost:9901')

@patch('kfp.client.auth.get_auth_response_local', lambda *args: None)
@patch('kfp.client.auth.is_ipython', lambda *args: False)
def test_get_auth_code_from_local_empty_response(self):
os.environ.pop('SSH_CONNECTION', None)
os.environ.pop('SSH_CLIENT', None)
self.assertRaises(ValueError, auth.get_auth_code, 'sample-client-id')

@patch('kfp.client.auth.get_auth_response_local',
lambda *args: 'this-is-an-invalid-response')
@patch('kfp.client.auth.is_ipython', lambda *args: False)
def test_get_auth_code_from_local_invalid_response(self):
os.environ.pop('SSH_CONNECTION', None)
os.environ.pop('SSH_CLIENT', None)
self.assertRaises(KeyError, auth.get_auth_code, 'sample-client-id')
20 changes: 4 additions & 16 deletions sdk/python/kfp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,18 +324,6 @@ def _load_config(
def _is_inverse_proxy_host(self, host: str) -> bool:
return bool(re.match(r'\S+.googleusercontent.com/{0,1}$', host))

def _is_ipython(self) -> bool:
"""Returns whether we are running in notebook."""
try:
import IPython
ipy = IPython.get_ipython()
if ipy is None:
return False
except ImportError:
return False

return True

def _get_url_prefix(self) -> str:
if self._uihost:
# User's own connection.
Expand Down Expand Up @@ -488,7 +476,7 @@ def create_experiment(
experiment = self._experiment_api.create_experiment(body=experiment)

link = f'{self._get_url_prefix()}/#/experiments/details/{experiment.experiment_id}'
if self._is_ipython():
if auth.is_ipython():
import IPython
html = f'<a href="{link}" target="_blank" >Experiment details</a>.'
IPython.display.display(IPython.display.HTML(html))
Expand Down Expand Up @@ -744,7 +732,7 @@ def run_pipeline(
response = self._run_api.create_run(body=run_body)

link = f'{self._get_url_prefix()}/#/runs/details/{response.run_id}'
if self._is_ipython():
if auth.is_ipython():
import IPython
html = (f'<a href="{link}" target="_blank" >Run details</a>.')
IPython.display.display(IPython.display.HTML(html))
Expand Down Expand Up @@ -1424,7 +1412,7 @@ def upload_pipeline(
description=description,
namespace=namespace)
link = f'{self._get_url_prefix()}/#/pipelines/details/{response.pipeline_id}'
if self._is_ipython():
if auth.is_ipython():
import IPython
html = f'<a href="{link}" target="_blank" >Pipeline details</a>.'
IPython.display.display(IPython.display.HTML(html))
Expand Down Expand Up @@ -1473,7 +1461,7 @@ def upload_pipeline_version(
pipeline_package_path, **kwargs)

link = f'{self._get_url_prefix()}/#/pipelines/details/{response.pipeline_id}/version/{response.pipeline_version_id}'
if self._is_ipython():
if auth.is_ipython():
import IPython
html = f'<a href="{link}" target="_blank" >Pipeline details</a>.'
IPython.display.display(IPython.display.HTML(html))
Expand Down
26 changes: 4 additions & 22 deletions sdk/python/kfp/client/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
import tempfile
import textwrap
import unittest
from unittest.mock import MagicMock
from unittest.mock import Mock
from unittest.mock import patch

from absl.testing import parameterized
from google.protobuf import json_format
from kfp.client import auth
from kfp.client import client
from kfp.compiler import Compiler
from kfp.dsl import component
Expand Down Expand Up @@ -194,24 +194,6 @@ class TestClient(parameterized.TestCase):
def setUp(self):
self.client = client.Client(namespace='ns1')

def test__is_ipython_return_false(self):
mock = MagicMock()
with patch.dict('sys.modules', IPython=mock):
mock.get_ipython.return_value = None
self.assertFalse(self.client._is_ipython())

def test__is_ipython_return_true(self):
mock = MagicMock()
with patch.dict('sys.modules', IPython=mock):
mock.get_ipython.return_value = 'Something'
self.assertTrue(self.client._is_ipython())

def test__is_ipython_should_raise_error(self):
mock = MagicMock()
with patch.dict('sys.modules', mock):
mock.side_effect = ImportError
self.assertFalse(self.client._is_ipython())

def test_wait_for_run_completion_invalid_token_should_raise_error(self):
with self.assertRaises(kfp_server_api.ApiException):
with patch.object(
Expand Down Expand Up @@ -371,7 +353,7 @@ def pipeline_test_upload_without_name(boolean: bool = True):

with patch.object(self.client._upload_api,
'upload_pipeline') as mock_upload_pipeline:
with patch.object(self.client, '_is_ipython', return_value=False):
with patch.object(auth, 'is_ipython', return_value=False):
with tempfile.TemporaryDirectory() as tmp_path:
pipeline_test_path = os.path.join(tmp_path, 'test.yaml')
Compiler().compile(
Expand Down Expand Up @@ -401,7 +383,7 @@ def pipeline_test_upload_without_name(boolean: bool = True):
def test_upload_pipeline_with_name(self, pipeline_name):
with patch.object(self.client._upload_api,
'upload_pipeline') as mock_upload_pipeline:
with patch.object(self.client, '_is_ipython', return_value=False):
with patch.object(auth, 'is_ipython', return_value=False):
self.client.upload_pipeline(
pipeline_package_path='fake.yaml',
pipeline_name=pipeline_name,
Expand All @@ -421,7 +403,7 @@ def test_upload_pipeline_with_name(self, pipeline_name):
def test_upload_pipeline_with_name_invalid(self, pipeline_name):
with patch.object(self.client._upload_api,
'upload_pipeline') as mock_upload_pipeline:
with patch.object(self.client, '_is_ipython', return_value=False):
with patch.object(auth, 'is_ipython', return_value=False):
with self.assertRaisesRegex(
ValueError,
'Invalid pipeline name. Pipeline name cannot be empty or contain only whitespace.'
Expand Down

0 comments on commit 37124b1

Please sign in to comment.