Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for accepting an Artifact Registry URL in pipeline_job #1405

Merged
merged 45 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
5825852
Add support for Artifact Registry in template_path
chongyouquan May 9, 2022
5664826
fix typo
chongyouquan May 10, 2022
a3849b8
Merge branch 'googleapis:main' into main
chongyouquan May 10, 2022
cd805b7
Merge branch 'main' of https://github.com/chongyouquan/python-aiplatform
chongyouquan May 10, 2022
4b420f6
update tests
chongyouquan May 10, 2022
3564245
fix AR path
chongyouquan May 10, 2022
a52e996
remove unused project
chongyouquan May 10, 2022
65732b0
add code for refreshing credentials
chongyouquan May 11, 2022
8cd82d0
add import for google.auth.transport
chongyouquan May 11, 2022
68f2c15
fix AR path
chongyouquan May 11, 2022
fb7ff6c
fix AR path
chongyouquan May 11, 2022
28d76cc
fix runtime_config
chongyouquan May 11, 2022
b4db0bd
test removing v1beta1
chongyouquan May 11, 2022
2bfdb2f
try using v1 directly instead
chongyouquan May 11, 2022
e5fd26c
update to use v1beta1
chongyouquan May 11, 2022
3f03be8
use select_version
chongyouquan May 11, 2022
05fb7ff
add back template_uri
chongyouquan May 11, 2022
2e4cb80
try adding back v1beta1
chongyouquan May 11, 2022
c4483dc
use select_version
chongyouquan May 11, 2022
61a0468
differentiate when to use select_version
chongyouquan May 11, 2022
6324d3f
test removing v1beta1 for pipeline_complete_states
chongyouquan May 11, 2022
113a00a
add tests for creating pipelines using v1beta1
chongyouquan May 20, 2022
f80ecb7
Merge branch 'main' into main
chongyouquan May 20, 2022
f2649aa
fix merge
chongyouquan May 20, 2022
e1f15ce
fix typo
chongyouquan May 20, 2022
856ba47
fix lint using blacken
chongyouquan May 20, 2022
086f5f5
fix regex
chongyouquan May 23, 2022
07f4ce5
Merge branch 'main' into main
chongyouquan May 23, 2022
656eb14
Merge branch 'main' into main
chongyouquan Jun 1, 2022
f67871b
Merge branch 'main' of https://github.com/chongyouquan/python-aiplatf…
chongyouquan Jun 6, 2022
ea5d41a
update to use v1 instead of v1beta1
chongyouquan Jun 6, 2022
e940880
add test for invalid url
chongyouquan Jun 6, 2022
787ac03
update error type
chongyouquan Jun 7, 2022
d70b58f
Merge branch 'main' into artifact_registry
chongyouquan Jun 7, 2022
72cdd9e
implement failure_policy
chongyouquan Jun 9, 2022
1be1223
Merge branch 'main' of https://github.com/googleapis/python-aiplatform
chongyouquan Jun 9, 2022
e1149b8
Merge branch 'main' into artifact_registry
chongyouquan Jun 9, 2022
082f160
use urllib.request instead of requests
chongyouquan Jun 9, 2022
c031144
Revert "implement failure_policy"
chongyouquan Jun 9, 2022
d145a83
Merge branch 'main' of https://github.com/googleapis/python-aiplatform
chongyouquan Jun 9, 2022
90f4ab1
Merge branch 'main' into artifact_registry
chongyouquan Jun 9, 2022
c245a5c
fix lint
chongyouquan Jun 10, 2022
0b5d654
Merge branch 'main' into artifact_registry
chongyouquan Jun 15, 2022
b47acce
Merge branch 'main' into artifact_registry
chongyouquan Jun 16, 2022
0ec0a24
Merge branch 'main' into artifact_registry
parthea Jun 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
# Pattern for valid names used as a Vertex resource name.
_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$")

# Pattern for an Artifact Registry URL: an "https://" URL whose host ends in
# "-kfp.pkg.dev", followed by "/" and any path. The part of the host before
# "-kfp" (word characters and hyphens) is captured as group 1.
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")


def _get_current_time() -> datetime.datetime:
"""Gets the current timestamp."""
Expand Down Expand Up @@ -111,8 +114,9 @@ def __init__(
Required. The user-defined name of this Pipeline.
template_path (str):
Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It
can be a local path or a Google Cloud Storage URI.
Example: "gs://project.name"
can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"),
or an Artifact Registry URI (e.g.
"https://us-central1-kfp.pkg.dev/proj/repo/pack/latest").
job_id (str):
Optional. The unique ID of the job run.
If not specified, pipeline name + timestamp will be used.
Expand Down Expand Up @@ -223,15 +227,20 @@ def __init__(
if enable_caching is not None:
_set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)

self._gca_resource = gca_pipeline_job.PipelineJob(
display_name=display_name,
pipeline_spec=pipeline_job["pipelineSpec"],
labels=labels,
runtime_config=runtime_config,
encryption_spec=initializer.global_config.get_encryption_spec(
pipeline_job_args = {
"display_name": display_name,
"pipeline_spec": pipeline_job["pipelineSpec"],
"labels": labels,
"runtime_config": runtime_config,
"encryption_spec": initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
)
}

if _VALID_AR_URL.match(template_path):
pipeline_job_args["template_uri"] = template_path
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved

self._gca_resource = gca_pipeline_job.PipelineJob(**pipeline_job_args)

@base.optional_sync()
def run(
Expand Down
42 changes: 42 additions & 0 deletions google/cloud/aiplatform/utils/yaml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@
# limitations under the License.
#

import re
from typing import Any, Dict, Optional
from urllib import request

from google.auth import credentials as auth_credentials
from google.auth import transport
from google.cloud import storage

# Pattern for an Artifact Registry URL: an "https://" URL whose host ends in
# "-kfp.pkg.dev", followed by "/" and any path. The part of the host before
# "-kfp" (word characters and hyphens) is captured as group 1.
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")


def load_yaml(
path: str,
Expand All @@ -42,6 +48,8 @@ def load_yaml(
"""
if path.startswith("gs://"):
return _load_yaml_from_gs_uri(path, project, credentials)
elif _VALID_AR_URL.match(path):
return _load_yaml_from_ar_uri(path, credentials)
else:
return _load_yaml_from_local_file(path)

Expand Down Expand Up @@ -95,3 +103,37 @@ def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]:
)
with open(file_path) as f:
return yaml.safe_load(f)


def _load_yaml_from_ar_uri(
    uri: str,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> Dict[str, Any]:
    """Loads data from a YAML document referenced by an Artifact Registry URI.

    Args:
        uri (str):
            Required. Artifact Registry URI for YAML document.
        credentials (auth_credentials.Credentials):
            Optional. Credentials to use with Artifact Registry. When given,
            they are refreshed if they are not currently valid, and their
            token (if any) is sent as a Bearer ``Authorization`` header. When
            omitted, the request is sent without an ``Authorization`` header.

    Returns:
        A Dict object representing the YAML document.

    Raises:
        ImportError: If pyyaml is not installed.
    """
    try:
        import yaml
    except ImportError:
        raise ImportError(
            "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
            'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
        )
    req = request.Request(uri)

    if credentials:
        if not credentials.valid:
            # NOTE(review): `transport.requests` resolves only if the
            # google.auth.transport.requests submodule has already been
            # imported somewhere in the process - confirm this is guaranteed.
            credentials.refresh(transport.requests.Request())
        if credentials.token:
            req.add_header("Authorization", "Bearer " + credentials.token)

    # Errors raised by urlopen are not handled here and propagate to the caller.
    response = request.urlopen(req)
    try:
        payload = response.read()
    finally:
        # Close explicitly so the underlying connection is released even when
        # read() fails; the original never closed the response.
        response.close()

    return yaml.safe_load(payload.decode("utf-8"))
92 changes: 92 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from unittest import mock
from importlib import reload
from unittest.mock import patch
from urllib import request
from datetime import datetime

from google.auth import credentials as auth_credentials
Expand Down Expand Up @@ -49,6 +50,7 @@
_TEST_SERVICE_ACCOUNT = "[email protected]"

_TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
chongyouquan marked this conversation as resolved.
Show resolved Hide resolved
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}"

Expand Down Expand Up @@ -283,6 +285,17 @@ def mock_load_yaml_and_json(job_spec):
yield mock_load_yaml_and_json


@pytest.fixture
def mock_request_urlopen(job_spec):
    """Patches ``urllib.request.urlopen`` and yields the patched mock.

    The fake response is wired so that ``urlopen(...).read().decode(...)``
    returns ``job_spec.encode()``.
    """
    decode_stub = mock.MagicMock(return_value=job_spec.encode())
    read_stub = mock.MagicMock()
    read_stub.return_value.decode = decode_stub

    with patch.object(request, "urlopen") as mock_urlopen:
        mock_urlopen.return_value.read = read_stub
        yield mock_urlopen


@pytest.mark.usefixtures("google_auth_mock")
class TestPipelineJob:
class FakePipelineJob(pipeline_jobs.PipelineJob):
Expand Down Expand Up @@ -379,6 +392,85 @@ def test_run_call_pipeline_service_create(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
    "job_spec",
    [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_artifact_registry(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    mock_request_urlopen,
    job_spec,
    mock_load_yaml_and_json,
    sync,
):
    """Runs a PipelineJob built from an Artifact Registry template URL.

    Checks that the create call receives a PipelineJob whose ``template_uri``
    is the Artifact Registry URL, that the job is fetched afterwards, and that
    the resulting resource equals the SUCCEEDED pipeline job.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    pipeline_job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_AR_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        enable_caching=True,
    )
    pipeline_job.run(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        sync=sync,
        create_request_timeout=None,
    )
    if not sync:
        pipeline_job.wait()

    # Build the runtime config proto the create request is expected to carry.
    runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
    json_format.ParseDict(
        {
            "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
        },
        runtime_config,
    )

    # The parametrized spec is either a full job (with "pipelineSpec") or a
    # bare pipeline spec; use whichever applies.
    parsed_spec = yaml.safe_load(job_spec)
    pipeline_spec = parsed_spec.get("pipelineSpec") or parsed_spec
    expected_pipeline_spec = {
        "components": {},
        "pipelineInfo": pipeline_spec["pipelineInfo"],
        "root": pipeline_spec["root"],
        "schemaVersion": "2.1.0",
    }

    # Construct expected request
    expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        pipeline_spec=expected_pipeline_spec,
        runtime_config=runtime_config,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        template_uri=_TEST_AR_TEMPLATE_PATH,
    )

    mock_pipeline_service_create.assert_called_once_with(
        parent=_TEST_PARENT,
        pipeline_job=expected_gapic_pipeline_job,
        pipeline_job_id=_TEST_PIPELINE_JOB_ID,
        timeout=None,
    )
    mock_pipeline_service_get.assert_called_with(
        name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
    )

    actual_resource = pipeline_job._gca_resource
    succeeded_job = make_pipeline_job(
        gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
    )
    assert actual_resource == succeeded_job

@pytest.mark.parametrize(
"job_spec",
[
Expand Down
27 changes: 25 additions & 2 deletions tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import json
import os
from typing import Callable, Dict, Optional
from unittest import mock
from urllib import request

import pytest
import yaml
Expand Down Expand Up @@ -560,13 +562,34 @@ def json_file(tmp_path):
yield json_file_path


@pytest.fixture(scope="function")
def mock_request_urlopen():
    """Patches ``urllib.request.urlopen`` and yields an Artifact Registry URL.

    While the patch is active, ``urlopen(...).read().decode(...)`` returns a
    JSON document for ``{"key": "val", "list": ["1", 2, 3.0]}``.
    """
    payload = {"key": "val", "list": ["1", 2, 3.0]}
    decode_stub = mock.MagicMock(return_value=json.dumps(payload))
    read_stub = mock.MagicMock()
    read_stub.return_value.decode = decode_stub

    with mock.patch.object(request, "urlopen") as mock_urlopen:
        mock_urlopen.return_value.read = read_stub
        yield "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"


class TestYamlUtils:
def test_load_yaml_from_local_file__with_json(self, yaml_file):
def test_load_yaml_from_local_file__with_yaml(self, yaml_file):
actual = yaml_utils.load_yaml(yaml_file)
expected = {"key": "val", "list": ["1", 2, 3.0]}
assert actual == expected

def test_load_yaml_from_local_file__with_yaml(self, json_file):
def test_load_yaml_from_local_file__with_json(self, json_file):
actual = yaml_utils.load_yaml(json_file)
expected = {"key": "val", "list": ["1", 2, 3.0]}
assert actual == expected

def test_load_yaml_from_ar_uri(self, mock_request_urlopen):
    """load_yaml returns the parsed document for the URL from the fixture."""
    ar_uri = mock_request_urlopen
    loaded = yaml_utils.load_yaml(ar_uri)
    assert loaded == {"key": "val", "list": ["1", 2, 3.0]}

def test_load_yaml_from_invalid_uri(self):
    """load_yaml raises FileNotFoundError for a pkg.dev URL without "-kfp"."""
    non_kfp_url = "https://us-docker.pkg.dev/v2/proj/repo/img/tags/list"
    with pytest.raises(FileNotFoundError):
        yaml_utils.load_yaml(non_kfp_url)