Skip to content

Commit

Permalink
feat: Add batch delete method in preview pipeline job class and unit …
Browse files Browse the repository at this point in the history
…test.

PiperOrigin-RevId: 599720798
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Jan 19, 2024
1 parent 066f32d commit b0b604e
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 2 deletions.
72 changes: 70 additions & 2 deletions google/cloud/aiplatform/preview/pipelinejob/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@
# limitations under the License.
#

from typing import Optional
from typing import List, Optional

from google.cloud.aiplatform.pipeline_jobs import (
PipelineJob as PipelineJobGa,
)
from google.cloud.aiplatform import pipeline_job_schedules
from google.cloud.aiplatform_v1.services.pipeline_service import (
PipelineServiceClient as PipelineServiceClientGa,
)
from google.cloud import aiplatform_v1beta1
from google.cloud.aiplatform import compat, pipeline_job_schedules
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils

from google.cloud.aiplatform.metadata import constants as metadata_constants
from google.cloud.aiplatform.metadata import experiment_resources
Expand Down Expand Up @@ -112,3 +118,65 @@ def create_schedule(
network=network,
create_request_timeout=create_request_timeout,
)

@classmethod
def batch_delete(
cls,
names: List[str],
project: Optional[str] = None,
location: Optional[str] = None,
) -> aiplatform_v1beta1.BatchDeletePipelineJobsResponse:
"""
Example Usage:
pipeline_job = aiplatform.PipelineJob(
display_name='job_display_name',
template_path='your_pipeline.yaml',
)
pipeline_job.batch_delete(
names=['pipeline_job_name', 'pipeline_job_name2']
)
Args:
names (List[str]):
Required. The fully-qualified resource name or ID of the
Pipeline Jobs to batch delete. Example:
"projects/123/locations/us-central1/pipelineJobs/456"
or "456" when project and location are initialized or passed.
project (str):
Optional. Project containing the Pipeline Jobs to
batch delete. If not set, the project given to `aiplatform.init`
will be used.
location (str):
Optional. Location containing the Pipeline Jobs to
batch delete. If not set, the location given to `aiplatform.init`
will be used.
Returns:
BatchDeletePipelineJobsResponse contains PipelineJobs deleted.
"""
user_project = project or initializer.global_config.project
user_location = location or initializer.global_config.location
parent = initializer.global_config.common_location_path(
project=user_project, location=user_location
)
pipeline_jobs_names = [
utils.full_resource_name(
resource_name=name,
resource_noun="pipelineJobs",
parse_resource_name_method=PipelineServiceClientGa.parse_pipeline_job_path,
format_resource_name_method=PipelineServiceClientGa.pipeline_job_path,
project=user_project,
location=user_location,
)
for name in names
]
request = aiplatform_v1beta1.BatchDeletePipelineJobsRequest(
parent=parent, names=pipeline_jobs_names
)
client = cls._instantiate_client(
location=user_location,
appended_user_agent=["preview-pipeline-jobs-batch-delete"],
)
v1beta1_client = client.select_version(compat.V1BETA1)
operation = v1beta1_client.batch_delete_pipeline_jobs(request)
return operation.result()
105 changes: 105 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from urllib import request
from datetime import datetime

from google.api_core import operation as ga_operation
from google.auth import credentials as auth_credentials
from google.cloud import aiplatform
from google.cloud.aiplatform import base
Expand All @@ -43,6 +44,20 @@
from google.cloud.aiplatform.compat.services import (
pipeline_service_client,
)
from google.cloud.aiplatform_v1beta1.types import (
pipeline_service as PipelineServiceV1Beta1,
)
from google.cloud.aiplatform_v1beta1.services import (
pipeline_service as v1beta1_pipeline_service,
)
from google.cloud.aiplatform_v1beta1.types import (
pipeline_job as v1beta1_pipeline_job,
pipeline_state as v1beta1_pipeline_state,
context as v1beta1_context,
)
from google.cloud.aiplatform.preview.pipelinejob import (
pipeline_jobs as preview_pipeline_jobs,
)
from google.cloud.aiplatform.compat.types import (
pipeline_job as gca_pipeline_job,
pipeline_state as gca_pipeline_state,
Expand All @@ -52,7 +67,9 @@
_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
_TEST_PIPELINE_JOB_DISPLAY_NAME = "sample-pipeline-job-display-name"
_TEST_PIPELINE_JOB_DISPLAY_NAME_2 = "sample-pipeline-job-display-name-2"
_TEST_PIPELINE_JOB_ID = "sample-test-pipeline-202111111"
_TEST_PIPELINE_JOB_ID_2 = "sample-test-pipeline-202111112"
_TEST_GCS_BUCKET_NAME = "my-bucket"
_TEST_GCS_OUTPUT_DIRECTORY = f"gs://{_TEST_GCS_BUCKET_NAME}/output_artifacts/"
_TEST_CREDENTIALS = auth_credentials.AnonymousCredentials()
Expand All @@ -66,6 +83,7 @@
_TEST_RESERVED_IP_RANGES = ["vertex-ai-ip-range"]

_TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}"
_TEST_PIPELINE_JOB_NAME_2 = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID_2}"
_TEST_PIPELINE_JOB_LIST_READ_MASK = field_mask.FieldMask(
paths=pipeline_constants._READ_MASK_FIELDS
)
Expand Down Expand Up @@ -237,6 +255,52 @@ def mock_pipeline_service_create():
yield mock_create_pipeline_job


@pytest.fixture
def mock_pipeline_v1beta1_service_batch_delete():
with mock.patch.object(
v1beta1_pipeline_service.PipelineServiceClient, "batch_delete_pipeline_jobs"
) as mock_batch_pipeline_jobs:
mock_batch_pipeline_jobs.return_value = (
make_batch_delete_pipeline_jobs_response()
)
mock_lro = mock.Mock(ga_operation.Operation)
mock_lro.result.return_value = make_batch_delete_pipeline_jobs_response()
mock_batch_pipeline_jobs.return_value = mock_lro
yield mock_batch_pipeline_jobs


def make_v1beta1_pipeline_job(name: str, state: v1beta1_pipeline_state.PipelineState):
return v1beta1_pipeline_job.PipelineJob(
name=name,
state=state,
create_time=_TEST_PIPELINE_CREATE_TIME,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
job_detail=v1beta1_pipeline_job.PipelineJobDetail(
pipeline_run_context=v1beta1_context.Context(
name=name,
)
),
)


def make_batch_delete_pipeline_jobs_response():
response = PipelineServiceV1Beta1.BatchDeletePipelineJobsResponse()
response.pipeline_jobs.append(
make_v1beta1_pipeline_job(
_TEST_PIPELINE_JOB_NAME,
v1beta1_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
)
)
response.pipeline_jobs.append(
make_v1beta1_pipeline_job(
_TEST_PIPELINE_JOB_NAME_2,
v1beta1_pipeline_state.PipelineState.PIPELINE_STATE_FAILED,
)
)
return response


@pytest.fixture
def mock_pipeline_bucket_exists():
def mock_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
Expand Down Expand Up @@ -1974,3 +2038,44 @@ def test_get_associated_experiment_from_pipeline_returns_experiment(
assert associated_experiment.resource_name == _TEST_CONTEXT_NAME

assert add_context_children_mock.call_count == 1

@pytest.mark.usefixtures(
"mock_pipeline_service_get",
"mock_pipeline_v1beta1_service_batch_delete",
)
@pytest.mark.parametrize(
"job_spec",
[
_TEST_PIPELINE_SPEC_JSON,
_TEST_PIPELINE_SPEC_YAML,
_TEST_PIPELINE_JOB,
_TEST_PIPELINE_SPEC_LEGACY_JSON,
_TEST_PIPELINE_SPEC_LEGACY_YAML,
_TEST_PIPELINE_JOB_LEGACY,
],
)
def test_create_two_and_batch_delete_pipeline_jobs_returns_response(
self,
mock_load_yaml_and_json,
mock_pipeline_v1beta1_service_batch_delete,
):
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
credentials=_TEST_CREDENTIALS,
)

job = preview_pipeline_jobs._PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
)

response = job.batch_delete(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
names=[_TEST_PIPELINE_JOB_ID, _TEST_PIPELINE_JOB_ID_2],
)

assert mock_pipeline_v1beta1_service_batch_delete.call_count == 1
assert len(response.pipeline_jobs) == 2

0 comments on commit b0b604e

Please sign in to comment.