Skip to content

Commit

Permalink
chore: Add scheduled pipelines client system test.
Browse files Browse the repository at this point in the history
fix: Remove Schedule read mask because ListSchedules does not support it.

PiperOrigin-RevId: 539819661
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Jun 13, 2023
1 parent 69c5f60 commit 1fda417
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 98 deletions.
15 changes: 0 additions & 15 deletions google/cloud/aiplatform/preview/constants/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,3 @@

# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL

# Fields to include in returned PipelineJobSchedule when enable_simple_view=True in PipelineJobSchedule.list()
_PIPELINE_JOB_SCHEDULE_READ_MASK_FIELDS = [
"name",
"display_name",
"start_time",
"end_time",
"max_run_count",
"started_run_count",
"state",
"create_time",
"update_time",
"cron",
"catch_up",
]
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@

_SCHEDULE_ERROR_STATES = schedule_constants._SCHEDULE_ERROR_STATES

_READ_MASK_FIELDS = schedule_constants._PIPELINE_JOB_SCHEDULE_READ_MASK_FIELDS


class PipelineJobSchedule(
_Schedule,
Expand Down Expand Up @@ -264,7 +262,6 @@ def list(
cls,
filter: Optional[str] = None,
order_by: Optional[str] = None,
enable_simple_view: bool = True,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
Expand All @@ -286,16 +283,6 @@ def list(
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
enable_simple_view (bool):
Optional. Whether to pass the `read_mask` parameter to the list call.
Defaults to False if not provided. This will improve the performance of calling
list(). However, the returned PipelineJobSchedule list will not include all fields for
each PipelineJobSchedule. Setting this to True will exclude the following fields in your
response: 'create_pipeline_job_request', 'next_run_time', 'last_pause_time',
'last_resume_time', 'max_concurrent_run_count', 'allow_queueing','last_scheduled_run_response'.
The following fields will be included in each PipelineJobSchedule resource in your
response: 'name', 'display_name', 'start_time', 'end_time', 'max_run_count',
'started_run_count', 'state', 'create_time', 'update_time', 'cron', 'catch_up'.
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
Expand All @@ -309,19 +296,9 @@ def list(
Returns:
List[PipelineJobSchedule] - A list of PipelineJobSchedule resource objects.
"""

read_mask_fields = None

if enable_simple_view:
read_mask_fields = field_mask.FieldMask(paths=_READ_MASK_FIELDS)
_LOGGER.warn(
"By enabling simple view, the PipelineJobSchedule resources returned from this method will not contain all fields."
)

return cls._list_with_local_order(
filter=filter,
order_by=order_by,
read_mask=read_mask_fields,
project=project,
location=location,
credentials=credentials,
Expand Down
98 changes: 98 additions & 0 deletions tests/system/aiplatform/test_pipeline_job_schedule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from google.cloud import aiplatform
from google.cloud.aiplatform.compat.types import schedule_v1beta1 as gca_schedule
from google.cloud.aiplatform.preview.pipelinejobschedule import pipeline_job_schedules
from tests.system.aiplatform import e2e_base

from kfp import components
from kfp.v2 import compiler

import pytest
from google.protobuf.json_format import MessageToDict


@pytest.mark.usefixtures(
    "tear_down_resources", "prepare_staging_bucket", "delete_staging_bucket"
)
class TestPreviewPipelineJobSchedule(e2e_base.TestEndToEnd):
    """End-to-end system test for the preview PipelineJobSchedule client.

    Exercises the full schedule lifecycle against a real Vertex AI project:
    create, pause, resume, wait, and list_jobs with/without the read mask.
    """

    _temp_prefix = "tmpvrtxsdk-e2e-pjs"

    def test_create_get_list(self, shared_state):
        """Creates a schedule, toggles its state, and verifies list_jobs read-mask behavior.

        Args:
            shared_state: pytest fixture dict used by e2e_base to track
                created resources for teardown.
        """

        # Component: trivial training step that only logs its inputs.
        def train(
            number_of_epochs: int,
            learning_rate: float,
        ):
            print(f"number_of_epochs={number_of_epochs}")
            print(f"learning_rate={learning_rate}")

        train_op = components.create_component_from_func(train)

        # Pipeline: single-step pipeline wrapping the component above.
        def training_pipeline(number_of_epochs: int = 2):
            train_op(
                number_of_epochs=number_of_epochs,
                # Pass a float, not the string "0.1": the component declares
                # `learning_rate: float`, and a string argument fails (or is
                # silently coerced by) KFP's parameter type checking.
                learning_rate=0.1,
            )

        # Creating the pipeline job schedule.
        aiplatform.init(
            project=e2e_base._PROJECT,
            location=e2e_base._LOCATION,
        )

        # Compile the pipeline to a local IR file consumed by PipelineJob.
        ir_file = "pipeline.json"
        compiler.Compiler().compile(
            pipeline_func=training_pipeline,
            package_path=ir_file,
            pipeline_name="training-pipeline",
        )
        job = aiplatform.PipelineJob(
            template_path=ir_file,
            display_name="display_name",
        )

        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
            pipeline_job=job, display_name="pipeline_job_schedule_display_name"
        )

        # Run every 2 minutes, at most twice, so the test finishes quickly.
        pipeline_job_schedule.create(cron_expression="*/2 * * * *", max_run_count=2)

        # Register for teardown by the tear_down_resources fixture.
        shared_state.setdefault("resources", []).append(pipeline_job_schedule)

        pipeline_job_schedule.pause()
        assert pipeline_job_schedule.state == gca_schedule.Schedule.State.PAUSED

        pipeline_job_schedule.resume()
        assert pipeline_job_schedule.state == gca_schedule.Schedule.State.ACTIVE

        # Block until all scheduled runs have completed.
        pipeline_job_schedule.wait()

        list_jobs_with_read_mask = pipeline_job_schedule.list_jobs(
            enable_simple_view=True
        )
        list_jobs_without_read_mask = pipeline_job_schedule.list_jobs()

        # enable_simple_view=True should apply the `read_mask` filter to limit
        # PipelineJob fields returned: `serviceAccount` is excluded by the mask,
        # so it must appear only in the unmasked response.
        assert "serviceAccount" in MessageToDict(
            list_jobs_without_read_mask[0].gca_resource._pb
        )
        assert "serviceAccount" not in MessageToDict(
            list_jobs_with_read_mask[0].gca_resource._pb
        )
62 changes: 2 additions & 60 deletions tests/unit/aiplatform/test_pipeline_job_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import pipeline_jobs
from google.cloud.aiplatform.compat.services import (
pipeline_service_client,
schedule_service_client_v1beta1 as schedule_service_client,
Expand All @@ -37,21 +38,16 @@
pipeline_state_v1beta1 as gca_pipeline_state,
schedule_v1beta1 as gca_schedule,
)
from google.cloud.aiplatform.preview.constants import (
schedules as schedule_constants,
)
from google.cloud.aiplatform.preview.pipelinejob import (
pipeline_jobs as preview_pipeline_jobs,
)
from google.cloud.aiplatform import pipeline_jobs
from google.cloud.aiplatform.preview.pipelinejobschedule import (
pipeline_job_schedules,
)
from google.cloud.aiplatform.utils import gcs_utils
import pytest
import yaml

from google.protobuf import field_mask_pb2 as field_mask
from google.protobuf import json_format

_TEST_PROJECT = "test-project"
Expand Down Expand Up @@ -1129,66 +1125,12 @@ def test_list_schedules(self, mock_schedule_service_list, mock_load_yaml_and_jso
create_request_timeout=None,
)

pipeline_job_schedule.list(enable_simple_view=False)
pipeline_job_schedule.list()

mock_schedule_service_list.assert_called_once_with(
request={"parent": _TEST_PARENT}
)

@pytest.mark.usefixtures(
"mock_schedule_service_create",
"mock_schedule_service_get",
"mock_schedule_bucket_exists",
)
@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
def test_list_schedules_with_read_mask(
self, mock_schedule_service_list, mock_load_yaml_and_json
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
enable_caching=True,
)

pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
pipeline_job=job,
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
)

pipeline_job_schedule.create(
cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
create_request_timeout=None,
)

pipeline_job_schedule.list(enable_simple_view=True)

test_pipeline_job_schedule_list_read_mask = field_mask.FieldMask(
paths=schedule_constants._PIPELINE_JOB_SCHEDULE_READ_MASK_FIELDS
)

mock_schedule_service_list.assert_called_once_with(
request={
"parent": _TEST_PARENT,
"read_mask": test_pipeline_job_schedule_list_read_mask,
},
)

@pytest.mark.usefixtures(
"mock_schedule_service_create",
"mock_schedule_service_get",
Expand Down

0 comments on commit 1fda417

Please sign in to comment.