feat: Add PipelineJobSchedule update method and unit tests.
PiperOrigin-RevId: 539259661
vertex-sdk-bot authored and copybara-github committed Jun 10, 2023
1 parent 50646be commit 69c5f60
Showing 2 changed files with 216 additions and 1 deletion.
@@ -41,7 +41,6 @@
)
from google.protobuf import field_mask_pb2 as field_mask


_LOGGER = base.Logger(__name__)

# Pattern for valid names used as a Vertex resource name.
@@ -53,6 +52,8 @@
# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = schedule_constants._VALID_HTTPS_URL

_SCHEDULE_ERROR_STATES = schedule_constants._SCHEDULE_ERROR_STATES

_READ_MASK_FIELDS = schedule_constants._PIPELINE_JOB_SCHEDULE_READ_MASK_FIELDS


@@ -385,3 +386,86 @@ def list_jobs(
            location=location,
            credentials=credentials,
        )

    def update(
        self,
        display_name: Optional[str] = None,
        cron_expression: Optional[str] = None,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        allow_queueing: Optional[bool] = None,
        max_run_count: Optional[int] = None,
        max_concurrent_run_count: Optional[int] = None,
    ) -> None:
        """Update an existing PipelineJobSchedule.

        Example usage:

        pipeline_job_schedule.update(
            display_name='updated-display-name',
            cron_expression='1 2 3 4 5',
        )

        Args:
            display_name (str):
                Optional. The user-defined name of this PipelineJobSchedule.
            cron_expression (str):
                Optional. Time specification (cron schedule expression) to launch scheduled runs.
                To explicitly set a timezone for the cron schedule, apply a prefix: "CRON_TZ=${IANA_TIME_ZONE}" or "TZ=${IANA_TIME_ZONE}".
                ${IANA_TIME_ZONE} must be a valid string from the IANA time zone database.
                For example, "CRON_TZ=America/New_York 1 * * * *" or "TZ=America/New_York 1 * * * *".
            start_time (str):
                Optional. Timestamp after which the first run can be scheduled.
                If unspecified, it defaults to the schedule creation timestamp.
            end_time (str):
                Optional. Timestamp after which no more runs will be scheduled.
                If unspecified, runs will be scheduled indefinitely.
            allow_queueing (bool):
                Optional. Whether new scheduled runs can be queued when the max_concurrent_run_count limit is reached.
            max_run_count (int):
                Optional. Maximum run count of the schedule.
                If specified, the schedule will be completed when either started_run_count >= max_run_count or when end_time is reached.
            max_concurrent_run_count (int):
                Optional. Maximum number of runs that can be started concurrently for this PipelineJobSchedule.

        Raises:
            RuntimeError: User tried to call update() before create().
        """
        pipeline_job_schedule = self._gca_resource
        if pipeline_job_schedule.state in _SCHEDULE_ERROR_STATES:
            raise RuntimeError(
                "Not updating PipelineJobSchedule: PipelineJobSchedule must be active or completed."
            )

        updated_fields = []
        if display_name is not None:
            updated_fields.append("display_name")
            setattr(pipeline_job_schedule, "display_name", display_name)
        if cron_expression is not None:
            updated_fields.append("cron")
            setattr(pipeline_job_schedule, "cron", cron_expression)
        if start_time is not None:
            updated_fields.append("start_time")
            setattr(pipeline_job_schedule, "start_time", start_time)
        if end_time is not None:
            updated_fields.append("end_time")
            setattr(pipeline_job_schedule, "end_time", end_time)
        if allow_queueing is not None:
            updated_fields.append("allow_queueing")
            setattr(pipeline_job_schedule, "allow_queueing", allow_queueing)
        if max_run_count is not None:
            updated_fields.append("max_run_count")
            setattr(pipeline_job_schedule, "max_run_count", max_run_count)
        if max_concurrent_run_count is not None:
            updated_fields.append("max_concurrent_run_count")
            setattr(
                pipeline_job_schedule,
                "max_concurrent_run_count",
                max_concurrent_run_count,
            )

        update_mask = field_mask.FieldMask(paths=updated_fields)
        self.api_client.update_schedule(
            schedule=pipeline_job_schedule,
            update_mask=update_mask,
        )
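For context, a minimal caller-side sketch of the new method (not part of the commit; the project, region, and the pre-created pipeline_job_schedule object below are illustrative assumptions):

# Assumes a PipelineJobSchedule was already created, e.g. via
# PipelineJobSchedule(...).create(...); project and region are placeholders.
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

# Only the fields passed here are added to the FieldMask, so the service
# leaves every other schedule field unchanged.
pipeline_job_schedule.update(
    cron_expression="TZ=America/New_York 0 9 * * 1",  # 09:00 every Monday, New York time
    max_run_count=10,
)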
131 changes: 131 additions & 0 deletions tests/unit/aiplatform/test_pipeline_job_schedules.py
@@ -69,6 +69,9 @@
_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT = 1
_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT = 2

_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION = "1 1 1 1 1"
_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT = 5

_TEST_TEMPLATE_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}/job_spec.json"
_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"
_TEST_HTTPS_TEMPLATE_PATH = "https://raw.githubusercontent.com/repo/pipeline.json"
@@ -371,6 +374,23 @@ def mock_pipeline_service_list():
        yield mock_list_pipeline_jobs


@pytest.fixture
def mock_schedule_service_update():
    with mock.patch.object(
        schedule_service_client.ScheduleServiceClient, "update_schedule"
    ) as mock_update_schedule:
        mock_update_schedule.return_value = gca_schedule.Schedule(
            name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
            state=gca_schedule.Schedule.State.COMPLETED,
            create_time=_TEST_PIPELINE_CREATE_TIME,
            cron=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
            max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
            create_pipeline_job_request=_TEST_CREATE_PIPELINE_JOB_REQUEST,
        )
        yield mock_update_schedule


@pytest.fixture
def mock_load_yaml_and_json(job_spec):
    with patch.object(storage.Blob, "download_as_bytes") as mock_load_yaml_and_json:
@@ -1304,3 +1324,114 @@ def test_resume_pipeline_job_schedule_without_created(
            pipeline_job_schedule.resume()

        assert e.match(regexp=r"Schedule resource has not been created")

    @pytest.mark.parametrize(
        "job_spec",
        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
    )
    def test_call_schedule_service_update(
        self,
        mock_schedule_service_create,
        mock_schedule_service_update,
        mock_schedule_service_get,
        mock_schedule_bucket_exists,
        job_spec,
        mock_load_yaml_and_json,
    ):
        """Updates a PipelineJobSchedule.

        Updates cron_expression and max_run_count.
        """
        aiplatform.init(
            project=_TEST_PROJECT,
            staging_bucket=_TEST_GCS_BUCKET_NAME,
            location=_TEST_LOCATION,
            credentials=_TEST_CREDENTIALS,
        )

        job = pipeline_jobs.PipelineJob(
            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
            template_path=_TEST_TEMPLATE_PATH,
            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
            input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
            enable_caching=True,
        )

        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
            pipeline_job=job,
            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
        )

        pipeline_job_schedule.create(
            cron_expression=_TEST_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
            max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
            service_account=_TEST_SERVICE_ACCOUNT,
            network=_TEST_NETWORK,
            create_request_timeout=None,
        )

        pipeline_job_schedule.update(
            cron_expression=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
            max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
        )

        expected_gapic_pipeline_job_schedule = gca_schedule.Schedule(
            name=_TEST_PIPELINE_JOB_SCHEDULE_NAME,
            state=gca_schedule.Schedule.State.COMPLETED,
            create_time=_TEST_PIPELINE_CREATE_TIME,
            cron=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
            max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
            max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
            create_pipeline_job_request=_TEST_CREATE_PIPELINE_JOB_REQUEST,
        )
        assert (
            pipeline_job_schedule._gca_resource == expected_gapic_pipeline_job_schedule
        )

    @pytest.mark.parametrize(
        "job_spec",
        [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
    )
    def test_call_schedule_service_update_before_create(
        self,
        mock_schedule_service_create,
        mock_schedule_service_update,
        mock_schedule_service_get,
        mock_schedule_bucket_exists,
        job_spec,
        mock_load_yaml_and_json,
    ):
        """Updates a PipelineJobSchedule.

        Raises error because PipelineJobSchedule should be created before update.
        """
        aiplatform.init(
            project=_TEST_PROJECT,
            staging_bucket=_TEST_GCS_BUCKET_NAME,
            location=_TEST_LOCATION,
            credentials=_TEST_CREDENTIALS,
        )

        job = pipeline_jobs.PipelineJob(
            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
            template_path=_TEST_TEMPLATE_PATH,
            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
            input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
            enable_caching=True,
        )

        pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
            pipeline_job=job,
            display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
        )

        with pytest.raises(RuntimeError) as e:
            pipeline_job_schedule.update(
                cron_expression=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_CRON_EXPRESSION,
                max_run_count=_TEST_UPDATED_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
            )

        assert e.match(
            regexp=r"Not updating PipelineJobSchedule: PipelineJobSchedule must be active or completed."
        )
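To run just the new tests from a local checkout (a sketch; assumes pytest and the SDK's unit-test requirements are installed):

# Invoke pytest programmatically, filtering to the update tests added here.
import pytest

pytest.main(
    [
        "tests/unit/aiplatform/test_pipeline_job_schedules.py",
        "-k",
        "schedule_service_update",
    ]
)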
