Skip to content

Commit

Permalink
feat: add support for HTTPS URI pipeline templates
Browse files Browse the repository at this point in the history
  • Loading branch information
TheMichaelHu committed Sep 21, 2022
1 parent 1cda4b4 commit 45e0d9f
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 36 deletions.
3 changes: 3 additions & 0 deletions google/cloud/aiplatform/constants/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
# Pattern for an Artifact Registry URL.
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")

# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = re.compile(r"^https:\/\/([\.\/\w-]+)\/.*(json|yaml|yml)$")

# Fields to include in returned PipelineJob when enable_simple_view=True in PipelineJob.list()
_READ_MASK_FIELDS = [
"name",
Expand Down
9 changes: 6 additions & 3 deletions google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@
# Pattern for an Artifact Registry URL.
_VALID_AR_URL = pipeline_constants._VALID_AR_URL

# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL

_READ_MASK_FIELDS = pipeline_constants._READ_MASK_FIELDS


Expand Down Expand Up @@ -131,8 +134,8 @@ def __init__(
template_path (str):
Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It
can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"),
or an Artifact Registry URI (e.g.
"https://us-central1-kfp.pkg.dev/proj/repo/pack/latest").
an Artifact Registry URI (e.g.
"https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"), or an HTTPS URI.
job_id (str):
Optional. The unique ID of the job run.
If not specified, pipeline name + timestamp will be used.
Expand Down Expand Up @@ -277,7 +280,7 @@ def __init__(
),
}

if _VALID_AR_URL.match(template_path):
if _VALID_AR_URL.match(template_path) or _VALID_HTTPS_URL.match(template_path):
pipeline_job_args["template_uri"] = template_path

self._gca_resource = gca_pipeline_job.PipelineJob(**pipeline_job_args)
Expand Down
69 changes: 42 additions & 27 deletions google/cloud/aiplatform/utils/yaml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import re
from types import ModuleType
from typing import Any, Dict, Optional
from urllib import request

from google.auth import credentials as auth_credentials
from google.auth import transport
from google.cloud import storage
from google.cloud.aiplatform.constants import pipeline as pipeline_constants

# Pattern for an Artifact Registry URL.
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")
_VALID_AR_URL = pipeline_constants._VALID_AR_URL

# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL


def load_yaml(
Expand All @@ -36,8 +39,8 @@ def load_yaml(
Args:
path (str):
Required. The path of the YAML document in Google Cloud Storage or
local.
Required. The path of the YAML document. It can be a local path, a
Google Cloud Storage URI, an Artifact Registry URI, or an HTTPS URI.
project (str):
Optional. Project to initiate the Storage client with.
credentials (auth_credentials.Credentials):
Expand All @@ -50,10 +53,25 @@ def load_yaml(
return _load_yaml_from_gs_uri(path, project, credentials)
elif _VALID_AR_URL.match(path):
return _load_yaml_from_ar_uri(path, credentials)
elif _VALID_HTTPS_URL.match(path):
return _load_yaml_from_https_uri(path)
else:
return _load_yaml_from_local_file(path)


def _maybe_import_yaml() -> ModuleType:
"""Tries to import the PyYAML module."""
try:
import yaml
except ImportError:
raise ImportError(
"PyYAML is not installed and is required to parse PipelineJob or "
'PipelineSpec files. Please install the SDK using "pip install '
'google-cloud-aiplatform[pipelines]"'
)
return yaml


def _load_yaml_from_gs_uri(
uri: str,
project: Optional[str] = None,
Expand All @@ -72,13 +90,7 @@ def _load_yaml_from_gs_uri(
Returns:
A Dict object representing the YAML document.
"""
try:
import yaml
except ImportError:
raise ImportError(
"pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
)
yaml = _maybe_import_yaml()
storage_client = storage.Client(project=project, credentials=credentials)
blob = storage.Blob.from_string(uri, storage_client)
return yaml.safe_load(blob.download_as_bytes())
Expand All @@ -94,13 +106,7 @@ def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]:
Returns:
A Dict object representing the YAML document.
"""
try:
import yaml
except ImportError:
raise ImportError(
"pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
)
yaml = _maybe_import_yaml()
with open(file_path) as f:
return yaml.safe_load(f)

Expand All @@ -112,21 +118,15 @@ def _load_yaml_from_ar_uri(
"""Loads data from a YAML document referenced by a Artifact Registry URI.
Args:
path (str):
uri (str):
Required. Artifact Registry URI for YAML document.
credentials (auth_credentials.Credentials):
Optional. Credentials to use with Artifact Registry.
Returns:
A Dict object representing the YAML document.
"""
try:
import yaml
except ImportError:
raise ImportError(
"pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
)
yaml = _maybe_import_yaml()
req = request.Request(uri)

if credentials:
Expand All @@ -137,3 +137,18 @@ def _load_yaml_from_ar_uri(
response = request.urlopen(req)

return yaml.safe_load(response.read().decode("utf-8"))


def _load_yaml_from_https_uri(uri: str) -> Dict[str, Any]:
"""Loads data from a YAML document referenced by an HTTPS URI.
Args:
uri (str):
Required. HTTPS URI for YAML document.
Returns:
A Dict object representing the YAML document.
"""
yaml = _maybe_import_yaml()
response = request.urlopen(uri)
return yaml.safe_load(response.read().decode("utf-8"))
83 changes: 83 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@

_TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
_TEST_HTTPS_TEMPLATE_PATH = "https://raw.githubusercontent.com/repo/pipeline.json"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}"

Expand Down Expand Up @@ -627,6 +628,88 @@ def test_run_call_pipeline_service_create_artifact_registry(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_https(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
mock_request_urlopen,
job_spec,
mock_load_yaml_and_json,
sync,
):
import yaml

aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_HTTPS_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
enable_caching=True,
)

job.run(
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
sync=sync,
create_request_timeout=None,
)

if not sync:
job.wait()

expected_runtime_config_dict = {
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
}
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)

job_spec = yaml.safe_load(job_spec)
pipeline_spec = job_spec.get("pipelineSpec") or job_spec

# Construct expected request
expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
pipeline_spec={
"components": {},
"pipelineInfo": pipeline_spec["pipelineInfo"],
"root": pipeline_spec["root"],
"schemaVersion": "2.1.0",
},
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
template_uri=_TEST_HTTPS_TEMPLATE_PATH,
)

mock_pipeline_service_create.assert_called_once_with(
parent=_TEST_PARENT,
pipeline_job=expected_gapic_pipeline_job,
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
timeout=None,
)

mock_pipeline_service_get.assert_called_with(
name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
)

assert job._gca_resource == make_pipeline_job(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
"job_spec",
[
Expand Down
39 changes: 33 additions & 6 deletions tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from typing import Callable, Dict, Optional
from unittest import mock
from unittest.mock import patch
from urllib import request
from urllib import request as urllib_request

import pytest
import yaml
Expand Down Expand Up @@ -751,15 +751,15 @@ def json_file(tmp_path):


@pytest.fixture(scope="function")
def mock_request_urlopen():
def mock_request_urlopen(request: str) -> str:
data = {"key": "val", "list": ["1", 2, 3.0]}
with mock.patch.object(request, "urlopen") as mock_urlopen:
with mock.patch.object(urllib_request, "urlopen") as mock_urlopen:
mock_read_response = mock.MagicMock()
mock_decode_response = mock.MagicMock()
mock_decode_response.return_value = json.dumps(data)
mock_read_response.return_value.decode = mock_decode_response
mock_urlopen.return_value.read = mock_read_response
yield "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
yield request.param


class TestYamlUtils:
Expand All @@ -773,11 +773,38 @@ def test_load_yaml_from_local_file__with_json(self, json_file):
expected = {"key": "val", "list": ["1", 2, 3.0]}
assert actual == expected

@pytest.mark.parametrize(
"mock_request_urlopen",
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
def test_load_yaml_from_ar_uri(self, mock_request_urlopen):
actual = yaml_utils.load_yaml(mock_request_urlopen)
expected = {"key": "val", "list": ["1", 2, 3.0]}
assert actual == expected

def test_load_yaml_from_invalid_uri(self):
@pytest.mark.parametrize(
"mock_request_urlopen",
[
"https://raw.githubusercontent.com/repo/pipeline.json",
"https://raw.githubusercontent.com/repo/pipeline.yaml",
"https://raw.githubusercontent.com/repo/pipeline.yml",
],
indirect=True,
)
def test_load_yaml_from_https_uri(self, mock_request_urlopen):
actual = yaml_utils.load_yaml(mock_request_urlopen)
expected = {"key": "val", "list": ["1", 2, 3.0]}
assert actual == expected

@pytest.mark.parametrize(
"uri",
[
"https://us-docker.pkg.dev/v2/proj/repo/img/tags/list",
"https://example.com/pipeline.exe",
"http://example.com/pipeline.yaml",
],
)
def test_load_yaml_from_invalid_uri(self, uri: str):
with pytest.raises(FileNotFoundError):
yaml_utils.load_yaml("https://us-docker.pkg.dev/v2/proj/repo/img/tags/list")
yaml_utils.load_yaml(uri)

0 comments on commit 45e0d9f

Please sign in to comment.