Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for HTTPS URI pipeline templates #1683

Merged
merged 2 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions google/cloud/aiplatform/constants/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@
_PIPELINE_ERROR_STATES = set([gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED])

# Pattern for valid names used as a Vertex resource name.
_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$")
_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$", re.IGNORECASE)

# Pattern for an Artifact Registry URL.
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*", re.IGNORECASE)

# Pattern for any JSON or YAML file over HTTPS. A literal "." is required
# before the extension so that only URIs ending in .json, .yaml, or .yml match
# (e.g. "https://host/repo/notjson" is rejected).
_VALID_HTTPS_URL = re.compile(r"^https:\/\/([\.\/\w-]+)\/.*\.(json|yaml|yml)$")
TheMichaelHu marked this conversation as resolved.
Show resolved Hide resolved

# Fields to include in returned PipelineJob when enable_simple_view=True in PipelineJob.list()
_READ_MASK_FIELDS = [
Expand Down
9 changes: 6 additions & 3 deletions google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@
# Pattern for an Artifact Registry URL.
_VALID_AR_URL = pipeline_constants._VALID_AR_URL

# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL

_READ_MASK_FIELDS = pipeline_constants._READ_MASK_FIELDS


Expand Down Expand Up @@ -131,8 +134,8 @@ def __init__(
template_path (str):
Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It
can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"),
or an Artifact Registry URI (e.g.
"https://us-central1-kfp.pkg.dev/proj/repo/pack/latest").
an Artifact Registry URI (e.g.
"https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"), or an HTTPS URI.
job_id (str):
Optional. The unique ID of the job run.
If not specified, pipeline name + timestamp will be used.
Expand Down Expand Up @@ -277,7 +280,7 @@ def __init__(
),
}

if _VALID_AR_URL.match(template_path):
if _VALID_AR_URL.match(template_path) or _VALID_HTTPS_URL.match(template_path):
pipeline_job_args["template_uri"] = template_path

self._gca_resource = gca_pipeline_job.PipelineJob(**pipeline_job_args)
Expand Down
64 changes: 34 additions & 30 deletions google/cloud/aiplatform/utils/yaml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import re
from types import ModuleType
from typing import Any, Dict, Optional
from urllib import request

from google.auth import credentials as auth_credentials
from google.auth import transport
from google.cloud import storage
from google.cloud.aiplatform.constants import pipeline as pipeline_constants

# Pattern for an Artifact Registry URL.
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")
_VALID_AR_URL = pipeline_constants._VALID_AR_URL

# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL


def load_yaml(
Expand All @@ -36,8 +39,8 @@ def load_yaml(

Args:
path (str):
Required. The path of the YAML document in Google Cloud Storage or
local.
Required. The path of the YAML document. It can be a local path, a
Google Cloud Storage URI, an Artifact Registry URI, or an HTTPS URI.
project (str):
Optional. Project to initiate the Storage client with.
credentials (auth_credentials.Credentials):
Expand All @@ -48,12 +51,31 @@ def load_yaml(
"""
if path.startswith("gs://"):
return _load_yaml_from_gs_uri(path, project, credentials)
elif _VALID_AR_URL.match(path):
return _load_yaml_from_ar_uri(path, credentials)
elif path.startswith("http://") or path.startswith("https://"):
if _VALID_AR_URL.match(path) or _VALID_HTTPS_URL.match(path):
return _load_yaml_from_https_uri(path, credentials)
else:
raise ValueError(
"Invalid HTTPS URI. If not using Artifact Registry, please "
"ensure the URI ends with .json, .yaml, or .yml."
)
else:
return _load_yaml_from_local_file(path)


def _maybe_import_yaml() -> ModuleType:
    """Imports and returns the PyYAML module.

    The import is performed inside the function, at call time.

    Returns:
        The imported ``yaml`` module.

    Raises:
        ImportError: If PyYAML cannot be imported. The original import error
            is attached as the explicit cause.
    """
    try:
        import yaml
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays visible
        # beneath the installation hint.
        raise ImportError(
            "PyYAML is not installed and is required to parse PipelineJob or "
            'PipelineSpec files. Please install the SDK using "pip install '
            'google-cloud-aiplatform[pipelines]"'
        ) from err
    return yaml


def _load_yaml_from_gs_uri(
uri: str,
project: Optional[str] = None,
Expand All @@ -72,13 +94,7 @@ def _load_yaml_from_gs_uri(
Returns:
A Dict object representing the YAML document.
"""
try:
import yaml
except ImportError:
raise ImportError(
"pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
)
yaml = _maybe_import_yaml()
storage_client = storage.Client(project=project, credentials=credentials)
blob = storage.Blob.from_string(uri, storage_client)
return yaml.safe_load(blob.download_as_bytes())
Expand All @@ -94,39 +110,27 @@ def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]:
Returns:
A Dict object representing the YAML document.
"""
try:
import yaml
except ImportError:
raise ImportError(
"pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
)
yaml = _maybe_import_yaml()
with open(file_path) as f:
return yaml.safe_load(f)


def _load_yaml_from_ar_uri(
def _load_yaml_from_https_uri(
uri: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Dict[str, Any]:
"""Loads data from a YAML document referenced by a Artifact Registry URI.

Args:
path (str):
uri (str):
Required. Artifact Registry or HTTPS URI for the YAML document.
credentials (auth_credentials.Credentials):
Optional. Credentials to use with Artifact Registry.

Returns:
A Dict object representing the YAML document.
"""
try:
import yaml
except ImportError:
raise ImportError(
"pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
)
yaml = _maybe_import_yaml()
req = request.Request(uri)

if credentials:
Expand Down
83 changes: 83 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@

_TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
_TEST_HTTPS_TEMPLATE_PATH = "https://raw.githubusercontent.com/repo/pipeline.json"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}"

Expand Down Expand Up @@ -627,6 +628,88 @@ def test_run_call_pipeline_service_create_artifact_registry(
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
)

@pytest.mark.parametrize(
    "job_spec",
    [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_https(
    self,
    mock_pipeline_service_create,
    mock_pipeline_service_get,
    mock_pipeline_bucket_exists,
    mock_request_urlopen,
    job_spec,
    mock_load_yaml_and_json,
    sync,
):
    """Runs a PipelineJob created from an HTTPS template path.

    Checks that the create request carries the HTTPS path as
    ``template_uri`` and that the job resource ends in the SUCCEEDED state.
    """
    import yaml

    aiplatform.init(
        project=_TEST_PROJECT,
        staging_bucket=_TEST_GCS_BUCKET_NAME,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )

    pipeline_job = pipeline_jobs.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        template_path=_TEST_HTTPS_TEMPLATE_PATH,
        job_id=_TEST_PIPELINE_JOB_ID,
        parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
        enable_caching=True,
    )
    pipeline_job.run(
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        sync=sync,
        create_request_timeout=None,
    )

    # An asynchronous run has to be awaited before inspecting the mocks.
    if not sync:
        pipeline_job.wait()

    # Expected runtime config, populated in place on the underlying proto.
    expected_runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
    json_format.ParseDict(
        {
            "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
            "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
        },
        expected_runtime_config,
    )

    # The spec may be a full PipelineJob (with a nested "pipelineSpec") or a
    # bare PipelineSpec.
    parsed_spec = yaml.safe_load(job_spec)
    spec_body = parsed_spec.get("pipelineSpec") or parsed_spec

    # Construct expected request
    expected_request_job = gca_pipeline_job.PipelineJob(
        display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
        pipeline_spec={
            "components": {},
            "pipelineInfo": spec_body["pipelineInfo"],
            "root": spec_body["root"],
            "schemaVersion": "2.1.0",
        },
        runtime_config=expected_runtime_config,
        service_account=_TEST_SERVICE_ACCOUNT,
        network=_TEST_NETWORK,
        template_uri=_TEST_HTTPS_TEMPLATE_PATH,
    )

    mock_pipeline_service_create.assert_called_once_with(
        parent=_TEST_PARENT,
        pipeline_job=expected_request_job,
        pipeline_job_id=_TEST_PIPELINE_JOB_ID,
        timeout=None,
    )
    mock_pipeline_service_get.assert_called_with(
        name=_TEST_PIPELINE_JOB_NAME, retry=base._DEFAULT_RETRY
    )

    succeeded_job = make_pipeline_job(
        gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
    )
    assert pipeline_job._gca_resource == succeeded_job

@pytest.mark.parametrize(
"job_spec",
[
Expand Down
45 changes: 38 additions & 7 deletions tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from typing import Callable, Dict, Optional
from unittest import mock
from unittest.mock import patch
from urllib import request
from urllib import request as urllib_request

import pytest
import yaml
Expand Down Expand Up @@ -751,15 +751,15 @@ def json_file(tmp_path):


@pytest.fixture(scope="function")
def mock_request_urlopen():
def mock_request_urlopen(request: str) -> str:
data = {"key": "val", "list": ["1", 2, 3.0]}
with mock.patch.object(request, "urlopen") as mock_urlopen:
with mock.patch.object(urllib_request, "urlopen") as mock_urlopen:
mock_read_response = mock.MagicMock()
mock_decode_response = mock.MagicMock()
mock_decode_response.return_value = json.dumps(data)
mock_read_response.return_value.decode = mock_decode_response
mock_urlopen.return_value.read = mock_read_response
yield "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
yield request.param


class TestYamlUtils:
Expand All @@ -773,11 +773,42 @@ def test_load_yaml_from_local_file__with_json(self, json_file):
expected = {"key": "val", "list": ["1", 2, 3.0]}
assert actual == expected

@pytest.mark.parametrize(
    "mock_request_urlopen",
    ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
    indirect=True,
)
def test_load_yaml_from_ar_uri(self, mock_request_urlopen):
    """Loads a document through an Artifact Registry URI with urlopen mocked."""
    # The fixture yields the parametrized URI.
    loaded = yaml_utils.load_yaml(mock_request_urlopen)
    assert loaded == {"key": "val", "list": ["1", 2, 3.0]}

def test_load_yaml_from_invalid_uri(self):
with pytest.raises(FileNotFoundError):
yaml_utils.load_yaml("https://us-docker.pkg.dev/v2/proj/repo/img/tags/list")
@pytest.mark.parametrize(
    "mock_request_urlopen",
    [
        "https://raw.githubusercontent.com/repo/pipeline.json",
        "https://raw.githubusercontent.com/repo/pipeline.yaml",
        "https://raw.githubusercontent.com/repo/pipeline.yml",
    ],
    indirect=True,
)
def test_load_yaml_from_https_uri(self, mock_request_urlopen):
    """Loads a document through generic HTTPS URIs ending in .json/.yaml/.yml."""
    # The fixture yields the parametrized URI.
    loaded = yaml_utils.load_yaml(mock_request_urlopen)
    assert loaded == {"key": "val", "list": ["1", 2, 3.0]}

@pytest.mark.parametrize(
    "uri",
    [
        "https://us-docker.pkg.dev/v2/proj/repo/img/tags/list",
        "https://example.com/pipeline.exe",
        "http://example.com/pipeline.yaml",
    ],
)
def test_load_yaml_from_invalid_uri(self, uri: str):
    """Checks that unsupported HTTP(S) URIs are rejected with a ValueError."""
    import re

    message = (
        "Invalid HTTPS URI. If not using Artifact Registry, please "
        "ensure the URI ends with .json, .yaml, or .yml."
    )
    # ``match`` is applied as a regular expression; escape the literal message
    # so its "." characters are not treated as wildcards.
    with pytest.raises(ValueError, match=re.escape(message)):
        yaml_utils.load_yaml(uri)