From ce027351a52a7ab2d6d5a785e43d1db722a6b02b Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Tue, 6 Apr 2021 17:20:31 -0600 Subject: [PATCH 1/9] Implementation of AiPlatformResourceNoun.list() --- google/cloud/aiplatform/base.py | 137 +++++++++++++++++++- google/cloud/aiplatform/datasets/dataset.py | 3 +- google/cloud/aiplatform/jobs.py | 4 + google/cloud/aiplatform/models.py | 2 + google/cloud/aiplatform/training_jobs.py | 2 + tests/unit/aiplatform/test_lro.py | 1 + 6 files changed, 144 insertions(+), 5 deletions(-) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 573af7807e..6b8a61b72b 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -17,16 +17,18 @@ import abc from concurrent import futures +import datetime import functools import inspect +import proto import threading from typing import Any, Callable, Dict, Optional, Sequence, Union -import proto - from google.auth import credentials as auth_credentials +from google.cloud import aiplatform from google.cloud.aiplatform import utils from google.cloud.aiplatform import initializer +from google.protobuf import field_mask_pb2 as field_mask class FutureManager(metaclass=abc.ABCMeta): @@ -223,7 +225,7 @@ class AiPlatformResourceNoun(metaclass=abc.ABCMeta): @property @classmethod @abc.abstractmethod - def client_class(cls) -> utils.AiPlatformServiceClient: + def client_class(cls) -> "utils.AiPlatformServiceClient": """Client class required to interact with resource.""" pass @@ -240,6 +242,12 @@ def _getter_method(cls) -> str: """Name of getter method of client class for retrieving the resource.""" pass + @property + @abc.abstractmethod + def _list_method(cls) -> str: + """Name of list method of client class for listing resources.""" + pass + @property @abc.abstractmethod def _delete_method(cls) -> str: @@ -278,7 +286,7 @@ def _instantiate_client( cls, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, - ) -> utils.AiPlatformServiceClient: + ) -> "utils.AiPlatformServiceClient": """Helper method to instantiate service client for resource noun. Args: @@ -334,6 +342,16 @@ def display_name(self) -> str: """Display name of this resource.""" return self._gca_resource.display_name + @property + def create_time(self) -> datetime.datetime: + """Time this resource was created.""" + return self._gca_resource.create_time + + @property + def update_time(self) -> datetime.datetime: + """Time this resource was last updated.""" + return self._gca_resource.update_time + def optional_sync( construct_object_on_arg: Optional[str] = None, @@ -549,6 +567,117 @@ def _sync_object_with_future_result( if value: setattr(self, attribute, value) + def _construct_sdk_resource_from_gapic( + self, gapic_resource: proto.Message + ) -> AiPlatformResourceNoun: + """Given a GAPIC object, return the SDK representation.""" + sdk_resource = self._empty_constructor() + sdk_resource._gca_resource = gapic_resource + return sdk_resource + + # TODO(b/144545165) - Improve documentation for list filtering once available + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + page_size: Optional[int] = None, + read_mask: Optional[field_mask.FieldMask] = None, + ) -> Sequence[AiPlatformResourceNoun]: + """List all instances of this AI Platform Resource. + + Example Usage: + + aiplatform.BatchPredictionJobs.list( + filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"', + ) + + aiplatform.Model.list(order_by="create_time desc, display_name") + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + page_size (int): + Optional. The standard list page size. + read_mask (field_mask.FieldMask): + Optional. Mask specifying which fields to read. + + Returns: + Sequence[AiPlatformResourceNoun] - A list of SDK resource objects + """ + + self = cls._empty_constructor() + + resource_list_method = getattr(self.api_client, self._list_method) + order_locally = False + + list_request = { + "parent": initializer.global_config.common_location_path(), + "filter": filter, + "page_size": page_size, + "read_mask": read_mask, + } + + # If list method does not offer `order_by` field, order locally + if ( + issubclass( + type(self), + ( + aiplatform.jobs._Job, + aiplatform.training_jobs._TrainingJob, + aiplatform.models.Model, + ), + ) + and order_by + ): + order_locally = True + elif order_by: + list_request["order_by"] = order_by + + resource_list = resource_list_method(request=list_request) or [] + + # Only return objects specific to the calling subclass, + # for example TabularDataset.list() only lists TabularDatasets + if issubclass(type(self), aiplatform.datasets.Dataset): + final_list = [ + self._construct_sdk_resource_from_gapic(gapic_resource) + for gapic_resource in resource_list + if gapic_resource.metadata_schema_uri + in self._supported_metadata_schema_uris + ] + + elif issubclass(type(self), aiplatform.training_jobs._TrainingJob): + final_list = [ + self._construct_sdk_resource_from_gapic(gapic_resource) + for gapic_resource in resource_list + if gapic_resource.training_task_definition + in self._supported_training_schemas + ] + + else: + final_list = [ + self._construct_sdk_resource_from_gapic(gapic_resource) + for gapic_resource in resource_list + ] + + # Client-side sorting when API doesn't support `order_by` + if order_locally: + desc = "desc" in order_by + order_by = order_by.replace("desc", "") + order_by = order_by.split(",") + + final_list.sort( + key=lambda x: tuple(getattr(x, field.strip()) for field in order_by), + reverse=desc, + ) + + return final_list + @optional_sync() def delete(self, sync: bool = True) -> None: """Deletes this AI Platform resource. WARNING: This deletion is permament. diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py index 872f736279..20e514c16d 100644 --- a/google/cloud/aiplatform/datasets/dataset.py +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -40,9 +40,10 @@ class Dataset(base.AiPlatformResourceNounWithFutureManager): _is_client_prediction_client = False _resource_noun = "datasets" _getter_method = "get_dataset" + _list_method = "list_datasets" _delete_method = "delete_dataset" - _supported_metadata_schema_uris: Optional[Tuple[str]] = None + _supported_metadata_schema_uris: Optional[Tuple[str]] = () def __init__( self, diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 7315a6f662..4587c9281d 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -176,6 +176,7 @@ class BatchPredictionJob(_Job): _resource_noun = "batchPredictionJobs" _getter_method = "get_batch_prediction_job" + _list_method = "list_batch_prediction_jobs" _cancel_method = "cancel_batch_prediction_job" _delete_method = "delete_batch_prediction_job" _job_type = "batch-predictions" @@ -676,6 +677,7 @@ def iter_outputs( class CustomJob(_Job): _resource_noun = "customJobs" _getter_method = "get_custom_job" + _list_method = "list_custom_job" _cancel_method = "cancel_custom_job" _delete_method = "delete_custom_job" _job_type = "training" @@ -685,6 +687,7 @@ class CustomJob(_Job): class DataLabelingJob(_Job): _resource_noun = "dataLabelingJobs" _getter_method = "get_data_labeling_job" + _list_method = "list_data_labeling_jobs" _cancel_method = "cancel_data_labeling_job" _delete_method = "delete_data_labeling_job" _job_type = "labeling-tasks" @@ -694,6 +697,7 @@ class DataLabelingJob(_Job): class HyperparameterTuningJob(_Job): _resource_noun = "hyperparameterTuningJobs" _getter_method = "get_hyperparameter_tuning_job" + _list_method = "list_hyperparameter_tuning_jobs" _cancel_method = "cancel_hyperparameter_tuning_job" _delete_method = "delete_hyperparameter_tuning_job" pass diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index ae1ac51bfd..23df01fc93 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -71,6 +71,7 @@ class Endpoint(base.AiPlatformResourceNounWithFutureManager): _is_client_prediction_client = False _resource_noun = "endpoints" _getter_method = "get_endpoint" + _list_method = "list_endpoints" _delete_method = "delete_endpoint" def __init__( @@ -1083,6 +1084,7 @@ class Model(base.AiPlatformResourceNounWithFutureManager): _is_client_prediction_client = False _resource_noun = "models" _getter_method = "get_model" + _list_method = "list_models" _delete_method = "delete_model" @property diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 6046291e68..402a0d8b3e 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -59,6 +59,7 @@ import proto + logging.basicConfig(level=logging.INFO, stream=sys.stdout) _LOGGER = logging.getLogger(__name__) @@ -77,6 +78,7 @@ class _TrainingJob(base.AiPlatformResourceNounWithFutureManager): _is_client_prediction_client = False _resource_noun = "trainingPipelines" _getter_method = "get_training_pipeline" + _list_method = "list_training_pipelines" _delete_method = "delete_training_pipeline" def __init__( diff --git a/tests/unit/aiplatform/test_lro.py b/tests/unit/aiplatform/test_lro.py index 26685d4f15..936a12e440 100644 --- a/tests/unit/aiplatform/test_lro.py +++ b/tests/unit/aiplatform/test_lro.py @@ -49,6 +49,7 @@ class AiPlatformResourceNounImpl(base.AiPlatformResourceNoun): _is_client_prediction_client = False _resource_noun = None _getter_method = None + _list_method = None _delete_method = None From 9a00cfd620b7e93231ae04864d9929ed53735751 Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Tue, 6 Apr 2021 19:03:26 -0600 Subject: [PATCH 2/9] Add list() tests --- tests/unit/aiplatform/test_datasets.py | 50 +++++++++- tests/unit/aiplatform/test_jobs.py | 2 + tests/unit/aiplatform/test_models.py | 122 +++++++++++++++++++------ 3 files changed, 144 insertions(+), 30 deletions(-) diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py index 60458bcc70..6ed0e8d860 100644 --- a/tests/unit/aiplatform/test_datasets.py +++ b/tests/unit/aiplatform/test_datasets.py @@ -106,6 +106,21 @@ # misc _TEST_OUTPUT_DIR = "gs://my-output-bucket" +_TEST_TABULAR_DATASET_LIST = [ + GapicDataset( + display_name="a", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR + ), + GapicDataset( + display_name="b", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR + ), + GapicDataset( + display_name="c", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR + ), +] + +_TEST_LIST_FILTER = 'display_name="abc"' +_TEST_LIST_ORDER_BY = "create_time desc" + @pytest.fixture def get_dataset_mock(): @@ -224,6 +239,13 @@ def export_data_mock(): yield export_data_mock +@pytest.fixture +def list_datasets_mock(): + with patch.object(DatasetServiceClient, "list_datasets") as list_datasets_mock: + list_datasets_mock.return_value = _TEST_TABULAR_DATASET_LIST + yield list_datasets_mock + + # TODO(b/171333554): Move reusable test fixtures to conftest.py file class TestDataset: def setup_method(self): @@ -669,18 +691,19 @@ class TestTabularDataset: def setup_method(self): reload(initializer) reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT) def teardown_method(self): initializer.global_pool.shutdown(wait=True) def test_init_dataset_tabular(self, get_dataset_tabular_mock): - aiplatform.init(project=_TEST_PROJECT) + datasets.TabularDataset(dataset_name=_TEST_NAME) get_dataset_tabular_mock.assert_called_once_with(name=_TEST_NAME) @pytest.mark.usefixtures("get_dataset_image_mock") def test_init_dataset_non_tabular(self): - aiplatform.init(project=_TEST_PROJECT) + with pytest.raises(ValueError): datasets.TabularDataset(dataset_name=_TEST_NAME) @@ -716,7 +739,6 @@ def test_create_dataset_with_default_encryption_key( @pytest.mark.usefixtures("get_dataset_tabular_mock") @pytest.mark.parametrize("sync", [True, False]) def test_create_dataset(self, create_dataset_mock, sync): - aiplatform.init(project=_TEST_PROJECT) my_dataset = datasets.TabularDataset.create( display_name=_TEST_DISPLAY_NAME, @@ -743,13 +765,33 @@ def test_create_dataset(self, create_dataset_mock, sync): @pytest.mark.usefixtures("get_dataset_tabular_mock") def test_no_import_data_method(self): - aiplatform.init(project=_TEST_PROJECT) my_dataset = datasets.TabularDataset(dataset_name=_TEST_NAME) with pytest.raises(NotImplementedError): my_dataset.import_data() + def test_list_dataset(self, list_datasets_mock): + + ds_list = aiplatform.TabularDataset.list( + filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY + ) + + list_datasets_mock.assert_called_once_with( + request={ + "parent": _TEST_PARENT, + "filter": _TEST_LIST_FILTER, + "order_by": _TEST_LIST_ORDER_BY, + "page_size": None, + "read_mask": None, + } + ) + + assert len(ds_list) == len(_TEST_TABULAR_DATASET_LIST) + + for ds in ds_list: + assert type(ds) == aiplatform.TabularDataset + class TestTextDataset: def setup_method(self): diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 255ba05088..30cb4622f7 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -131,6 +131,7 @@ ) _TEST_JOB_GET_METHOD_NAME = "get_fake_job" +_TEST_JOB_LIST_METHOD_NAME = "list_fake_job" _TEST_JOB_CANCEL_METHOD_NAME = "cancel_fake_job" _TEST_JOB_DELETE_METHOD_NAME = "delete_fake_job" _TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/fakeJobs/{_TEST_ID}" @@ -160,6 +161,7 @@ class FakeJob(jobs._Job): _job_type = "fake-job" _resource_noun = "fakeJobs" _getter_method = _TEST_JOB_GET_METHOD_NAME + _list_method = _TEST_JOB_LIST_METHOD_NAME _cancel_method = _TEST_JOB_CANCEL_METHOD_NAME _delete_method = _TEST_JOB_DELETE_METHOD_NAME resource_name = _TEST_JOB_RESOURCE_NAME diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 8d32bbe2c6..9aff376520 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -18,6 +18,7 @@ import importlib import pytest from unittest import mock +from datetime import datetime, timedelta from google.api_core import operation as ga_operation from google.auth import credentials as auth_credentials @@ -112,6 +113,24 @@ kms_key_name=_TEST_ENCRYPTION_KEY_NAME ) +_TEST_OUTPUT_DIR = "gs://my-output-bucket" + +_TEST_MODEL_LIST = [ + gca_model.Model( + display_name="aac", create_time=datetime.now() - timedelta(minutes=15) + ), + gca_model.Model( + display_name="aab", create_time=datetime.now() - timedelta(minutes=5) + ), + gca_model.Model( + display_name="aaa", create_time=datetime.now() - timedelta(minutes=10) + ), +] + +_TEST_LIST_FILTER = 'display_name="abc"' +_TEST_LIST_ORDER_BY_CREATE_TIME = "create_time desc" +_TEST_LIST_ORDER_BY_DISPLAY_NAME = "display_name" + @pytest.fixture def get_endpoint_mock(): @@ -146,6 +165,13 @@ def delete_model_mock(): yield delete_model_mock +@pytest.fixture +def list_models_mock(): + with mock.patch.object(ModelServiceClient, "list_models") as list_models_mock: + list_models_mock.return_value = _TEST_MODEL_LIST + yield list_models_mock + + @pytest.fixture def deploy_model_mock(): with mock.patch.object(EndpointServiceClient, "deploy_model") as deploy_model_mock: @@ -192,6 +218,7 @@ class TestModel: def setup_method(self): importlib.reload(initializer) importlib.reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) def teardown_method(self): initializer.global_pool.shutdown(wait=True) @@ -236,7 +263,7 @@ def test_constructor_create_client_with_custom_location(self): ) def test_constructor_creates_client_with_custom_credentials(self): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( initializer.global_config, "create_client" ) as create_client_mock: @@ -252,7 +279,7 @@ def test_constructor_creates_client_with_custom_credentials(self): ) def test_constructor_gets_model(self): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( initializer.global_config, "create_client" ) as create_client_mock: @@ -268,7 +295,7 @@ def test_constructor_gets_model(self): ) def test_constructor_gets_model_with_custom_project(self): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( initializer.global_config, "create_client" ) as create_client_mock: @@ -283,7 +310,7 @@ def test_constructor_gets_model_with_custom_project(self): ) def test_constructor_gets_model_with_custom_location(self): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( initializer.global_config, "create_client" ) as create_client_mock: @@ -300,7 +327,6 @@ def test_constructor_gets_model_with_custom_location(self): @pytest.mark.parametrize("sync", [True, False]) def test_upload_uploads_and_gets_model(self, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with mock.patch.object( initializer.global_config, "create_client" ) as create_client_mock: @@ -348,8 +374,6 @@ def test_upload_uploads_and_gets_model(self, sync): def test_upload_raises_with_impartial_explanation_spec(self): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - with pytest.raises(ValueError) as e: models.Model.upload( display_name=_TEST_MODEL_NAME, @@ -364,7 +388,6 @@ def test_upload_raises_with_impartial_explanation_spec(self): @pytest.mark.parametrize("sync", [True, False]) def test_upload_uploads_and_gets_model_with_all_args(self, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with mock.patch.object( initializer.global_config, "create_client" ) as create_client_mock: @@ -449,7 +472,6 @@ def test_upload_uploads_and_gets_model_with_all_args(self, sync): @pytest.mark.parametrize("sync", [True, False]) def test_upload_uploads_and_gets_model_with_custom_project(self, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) with mock.patch.object( initializer.global_config, "create_client" ) as create_client_mock: @@ -500,7 +522,7 @@ def test_upload_uploads_and_gets_model_with_custom_project(self, sync): @pytest.mark.parametrize("sync", [True, False]) def test_upload_uploads_and_gets_model_with_custom_location(self, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( initializer.global_config, "create_client" ) as create_client_mock: @@ -552,7 +574,7 @@ def test_upload_uploads_and_gets_model_with_custom_location(self, sync): @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") @pytest.mark.parametrize("sync", [True, False]) def test_deploy(self, deploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) test_endpoint = models.Endpoint(_TEST_ID) @@ -581,7 +603,7 @@ def test_deploy(self, deploy_model_mock, sync): ) @pytest.mark.parametrize("sync", [True, False]) def test_deploy_no_endpoint(self, deploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) test_endpoint = test_model.deploy(sync=sync) @@ -608,7 +630,7 @@ def test_deploy_no_endpoint(self, deploy_model_mock, sync): ) @pytest.mark.parametrize("sync", [True, False]) def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) test_endpoint = test_model.deploy( machine_type=_TEST_MACHINE_TYPE, @@ -645,7 +667,7 @@ def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync): ) @pytest.mark.parametrize("sync", [True, False]) def test_deploy_no_endpoint_with_explanations(self, deploy_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) test_endpoint = test_model.deploy( machine_type=_TEST_MACHINE_TYPE, @@ -687,7 +709,7 @@ def test_deploy_no_endpoint_with_explanations(self, deploy_model_mock, sync): "get_endpoint_mock", "get_model_mock", "create_endpoint_mock" ) def test_deploy_raises_with_impartial_explanation_spec(self): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) with pytest.raises(ValueError) as e: @@ -755,9 +777,7 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a def test_batch_predict_gcs_source_and_dest( self, create_batch_prediction_job_mock, sync ): - aiplatform.init( - project=_TEST_PROJECT, location=_TEST_LOCATION, - ) + test_model = models.Model(_TEST_ID) # Make SDK batch_predict method call @@ -801,7 +821,7 @@ def test_batch_predict_gcs_source_and_dest( def test_batch_predict_gcs_source_bq_dest( self, create_batch_prediction_job_mock, sync ): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) # Make SDK batch_predict method call @@ -843,7 +863,7 @@ def test_batch_predict_gcs_source_bq_dest( @pytest.mark.parametrize("sync", [True, False]) @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) creds = auth_credentials.AnonymousCredentials() @@ -914,7 +934,7 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") def test_batch_predict_no_source(self, create_batch_prediction_job_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) # Make SDK batch_predict method call without source @@ -928,7 +948,7 @@ def test_batch_predict_no_source(self, create_batch_prediction_job_mock): @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") def test_batch_predict_two_sources(self, create_batch_prediction_job_mock): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) # Make SDK batch_predict method call with two sources @@ -944,7 +964,7 @@ def test_batch_predict_two_sources(self, create_batch_prediction_job_mock): @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") def test_batch_predict_no_destination(self): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) # Make SDK batch_predict method call without destination @@ -958,7 +978,7 @@ def test_batch_predict_no_destination(self): @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") def test_batch_predict_wrong_instance_format(self): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) # Make SDK batch_predict method call @@ -974,7 +994,7 @@ def test_batch_predict_wrong_instance_format(self): @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock") def test_batch_predict_wrong_prediction_format(self): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) # Make SDK batch_predict method call @@ -991,7 +1011,7 @@ def test_batch_predict_wrong_prediction_format(self): @pytest.mark.usefixtures("get_model_mock") @pytest.mark.parametrize("sync", [True, False]) def test_delete_model(self, delete_model_mock, sync): - aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_model = models.Model(_TEST_ID) test_model.delete(sync=sync) @@ -999,3 +1019,53 @@ def test_delete_model(self, delete_model_mock, sync): test_model.wait() delete_model_mock.assert_called_once_with(name=test_model.resource_name) + + def test_list_model_order_by_time(self, list_models_mock): + """Test call to Model.list() and ensure list is returned in descending order of create_time""" + + ds_list = aiplatform.Model.list( + filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY_CREATE_TIME + ) + + # `order_by` is not passed to API since it is not an accepted field + list_models_mock.assert_called_once_with( + request={ + "parent": _TEST_PARENT, + "filter": _TEST_LIST_FILTER, + "page_size": None, + "read_mask": None, + } + ) + + assert len(ds_list) == len(_TEST_MODEL_LIST) + + for ds in ds_list: + assert type(ds) == aiplatform.Model + + assert ds_list[0].create_time > ds_list[1].create_time > ds_list[2].create_time + + def test_list_model_order_by_display_name(self, list_models_mock): + """Test call to Model.list() and ensure list is returned in order of display_name""" + + ds_list = aiplatform.Model.list( + filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY_DISPLAY_NAME + ) + + # `order_by` is not passed to API since it is not an accepted field + list_models_mock.assert_called_once_with( + request={ + "parent": _TEST_PARENT, + "filter": _TEST_LIST_FILTER, + "page_size": None, + "read_mask": None, + } + ) + + assert len(ds_list) == len(_TEST_MODEL_LIST) + + for ds in ds_list: + assert type(ds) == aiplatform.Model + + assert ( + ds_list[0].display_name < ds_list[1].display_name < ds_list[2].display_name + ) From ea6c0a1504a4af11910823241e779bb50c7e3cfb Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Wed, 7 Apr 2021 10:53:43 -0600 Subject: [PATCH 3/9] Set credentials once for all list objects --- google/cloud/aiplatform/base.py | 38 +++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 6b8a61b72b..85df6096e9 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -568,10 +568,12 @@ def _sync_object_with_future_result( setattr(self, attribute, value) def _construct_sdk_resource_from_gapic( - self, gapic_resource: proto.Message + self, + gapic_resource: proto.Message, + credentials: Optional[auth_credentials.Credentials] = None, ) -> AiPlatformResourceNoun: """Given a GAPIC object, return the SDK representation.""" - sdk_resource = self._empty_constructor() + sdk_resource = self._empty_constructor(credentials=credentials) sdk_resource._gca_resource = gapic_resource return sdk_resource @@ -610,9 +612,17 @@ def list( Returns: Sequence[AiPlatformResourceNoun] - A list of SDK resource objects """ + _UNSUPPORTED_LIST_ORDER_BY_TYPES = ( + aiplatform.jobs._Job, + aiplatform.models.Endpoint, + aiplatform.models.Model, + aiplatform.training_jobs._TrainingJob, + ) self = cls._empty_constructor() + creds = initializer.global_config.credentials + resource_list_method = getattr(self.api_client, self._list_method) order_locally = False @@ -624,17 +634,7 @@ def list( } # If list method does not offer `order_by` field, order locally - if ( - issubclass( - type(self), - ( - aiplatform.jobs._Job, - aiplatform.training_jobs._TrainingJob, - aiplatform.models.Model, - ), - ) - and order_by - ): + if order_by and issubclass(type(self), _UNSUPPORTED_LIST_ORDER_BY_TYPES): order_locally = True elif order_by: list_request["order_by"] = order_by @@ -645,7 +645,9 @@ def list( # for example TabularDataset.list() only lists TabularDatasets if issubclass(type(self), aiplatform.datasets.Dataset): final_list = [ - self._construct_sdk_resource_from_gapic(gapic_resource) + self._construct_sdk_resource_from_gapic( + gapic_resource, credentials=creds + ) for gapic_resource in resource_list if gapic_resource.metadata_schema_uri in self._supported_metadata_schema_uris @@ -653,7 +655,9 @@ def list( elif issubclass(type(self), aiplatform.training_jobs._TrainingJob): final_list = [ - self._construct_sdk_resource_from_gapic(gapic_resource) + self._construct_sdk_resource_from_gapic( + gapic_resource, credentials=creds + ) for gapic_resource in resource_list if gapic_resource.training_task_definition in self._supported_training_schemas @@ -661,7 +665,9 @@ def list( else: final_list = [ - self._construct_sdk_resource_from_gapic(gapic_resource) + self._construct_sdk_resource_from_gapic( + gapic_resource, credentials=creds + ) for gapic_resource in resource_list ] From 7b06b61cdcb4566c341692e32bbaa474ee4db475 Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Wed, 7 Apr 2021 11:45:44 -0600 Subject: [PATCH 4/9] Lint base.py --- google/cloud/aiplatform/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 28ceda9ccc..656d22a487 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -360,6 +360,7 @@ def create_time(self) -> datetime.datetime: def update_time(self) -> datetime.datetime: """Time this resource was last updated.""" return self._gca_resource.update_time + def __repr__(self) -> str: return f"{object.__repr__(self)} \nresource name: {self.resource_name}" From 42b35e3a17c1a589f44fb85aa4dae5f215b0466a Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Fri, 9 Apr 2021 00:00:04 -0600 Subject: [PATCH 5/9] Address most requested changes --- google/cloud/aiplatform/base.py | 39 +++++++++++++-------- google/cloud/aiplatform/datasets/dataset.py | 2 +- tests/unit/aiplatform/test_datasets.py | 13 ++++--- tests/unit/aiplatform/test_models.py | 14 ++------ 4 files changed, 36 insertions(+), 32 deletions(-) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 656d22a487..f52d1d10c5 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -22,13 +22,12 @@ import inspect import proto import threading -from typing import Any, Callable, Dict, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union from google.auth import credentials as auth_credentials from google.cloud import aiplatform from google.cloud.aiplatform import utils from google.cloud.aiplatform import initializer -from google.protobuf import field_mask_pb2 as field_mask class FutureManager(metaclass=abc.ABCMeta): @@ -359,6 +358,7 @@ def create_time(self) -> datetime.datetime: @property def update_time(self) -> datetime.datetime: """Time this resource was last updated.""" + self._sync_gca_resource() return self._gca_resource.update_time def __repr__(self) -> str: @@ -582,22 +582,28 @@ def _sync_object_with_future_result( def _construct_sdk_resource_from_gapic( self, gapic_resource: proto.Message, + project: Optional[str] = None, + location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, ) -> AiPlatformResourceNoun: """Given a GAPIC object, return the SDK representation.""" - sdk_resource = self._empty_constructor(credentials=credentials) + sdk_resource = self._empty_constructor( + project=project, location=location, credentials=credentials + ) sdk_resource._gca_resource = gapic_resource return sdk_resource - # TODO(b/144545165) - Improve documentation for list filtering once available + # TODO(b/144545165): Improve documentation for list filtering once available + # TODO(b/184910159): Expose `page_size` field in list method @classmethod def list( cls, filter: Optional[str] = None, order_by: Optional[str] = None, - page_size: Optional[int] = None, - read_mask: Optional[field_mask.FieldMask] = None, - ) -> Sequence[AiPlatformResourceNoun]: + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[AiPlatformResourceNoun]: """List all instances of this AI Platform Resource. Example Usage: @@ -616,10 +622,15 @@ def list( Optional. A comma-separated list of fields to order by, sorted in ascending order. Use "desc" after a field name for descending. Supported fields: `display_name`, `create_time`, `update_time` - page_size (int): - Optional. The standard list page size. - read_mask (field_mask.FieldMask): - Optional. Mask specifying which fields to read. + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. Returns: Sequence[AiPlatformResourceNoun] - A list of SDK resource objects @@ -631,7 +642,9 @@ def list( aiplatform.training_jobs._TrainingJob, ) - self = cls._empty_constructor() + self = cls._empty_constructor( + project=project, location=location, credentials=credentials + ) creds = initializer.global_config.credentials @@ -641,8 +654,6 @@ def list( list_request = { "parent": initializer.global_config.common_location_path(), "filter": filter, - "page_size": page_size, - "read_mask": read_mask, } # If list method does not offer `order_by` field, order locally diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py index 20e514c16d..48ccd46c7f 100644 --- a/google/cloud/aiplatform/datasets/dataset.py +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -43,7 +43,7 @@ class Dataset(base.AiPlatformResourceNounWithFutureManager): _list_method = "list_datasets" _delete_method = "delete_dataset" - _supported_metadata_schema_uris: Optional[Tuple[str]] = () + _supported_metadata_schema_uris: Tuple[str] = () def __init__( self, diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py index 6ed0e8d860..71f592c9b8 100644 --- a/tests/unit/aiplatform/test_datasets.py +++ b/tests/unit/aiplatform/test_datasets.py @@ -106,13 +106,17 @@ # misc _TEST_OUTPUT_DIR = "gs://my-output-bucket" -_TEST_TABULAR_DATASET_LIST = [ +_TEST_DATASET_LIST = [ GapicDataset( display_name="a", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR ), + GapicDataset( + display_name="d", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR + ), GapicDataset( display_name="b", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR ), + GapicDataset(display_name="e", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT), GapicDataset( display_name="c", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR ), @@ -242,7 +246,7 @@ def export_data_mock(): @pytest.fixture def list_datasets_mock(): with patch.object(DatasetServiceClient, "list_datasets") as list_datasets_mock: - list_datasets_mock.return_value = _TEST_TABULAR_DATASET_LIST + list_datasets_mock.return_value = _TEST_DATASET_LIST yield list_datasets_mock @@ -782,12 +786,11 @@ def test_list_dataset(self, list_datasets_mock): "parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER, "order_by": _TEST_LIST_ORDER_BY, - "page_size": None, - "read_mask": None, } ) - assert len(ds_list) == len(_TEST_TABULAR_DATASET_LIST) + # Ensure returned list is smaller since it filtered out non-tabular datasets + assert len(ds_list) < len(_TEST_DATASET_LIST) for ds in ds_list: assert type(ds) == aiplatform.TabularDataset diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 98c1161d61..9101532f83 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -1030,12 +1030,7 @@ def test_list_model_order_by_time(self, list_models_mock): # `order_by` is not passed to API since it is not an accepted field list_models_mock.assert_called_once_with( - request={ - "parent": _TEST_PARENT, - "filter": _TEST_LIST_FILTER, - "page_size": None, - "read_mask": None, - } + request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} ) assert len(ds_list) == len(_TEST_MODEL_LIST) @@ -1054,12 +1049,7 @@ def test_list_model_order_by_display_name(self, list_models_mock): # `order_by` is not passed to API since it is not an accepted field list_models_mock.assert_called_once_with( - request={ - "parent": _TEST_PARENT, - "filter": _TEST_LIST_FILTER, - "page_size": None, - "read_mask": None, - } + request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} ) assert len(ds_list) == len(_TEST_MODEL_LIST) From cabef1b2a743cd8476aec5b3128d688ff62871f8 Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Fri, 9 Apr 2021 01:49:22 -0600 Subject: [PATCH 6/9] Add two private list methods with subclass list methods --- google/cloud/aiplatform/base.py | 148 ++++++++++---------- google/cloud/aiplatform/datasets/dataset.py | 56 +++++++- google/cloud/aiplatform/jobs.py | 49 ++++++- google/cloud/aiplatform/models.py | 94 +++++++++++++ google/cloud/aiplatform/training_jobs.py | 54 +++++++ tests/unit/aiplatform/test_datasets.py | 6 +- 6 files changed, 329 insertions(+), 78 deletions(-) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index f52d1d10c5..ca5fcc62d2 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -25,7 +25,6 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union from google.auth import credentials as auth_credentials -from google.cloud import aiplatform from google.cloud.aiplatform import utils from google.cloud.aiplatform import initializer @@ -595,6 +594,76 @@ def _construct_sdk_resource_from_gapic( # TODO(b/144545165): Improve documentation for list filtering once available # TODO(b/184910159): Expose `page_size` field in list method + @classmethod + def _list( + cls, + cls_filter: Callable[[proto.Message], bool] = lambda _: True, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[AiPlatformResourceNoun]: + + self = cls._empty_constructor( + project=project, location=location, credentials=credentials + ) + + # Fetch credentials once and re-use for all `_empty_constructor()` calls + creds = initializer.global_config.credentials + + resource_list_method = getattr(self.api_client, self._list_method) + + list_request = { + "parent": initializer.global_config.common_location_path(), + "filter": filter, + } + + if order_by: + list_request["order_by"] = order_by + + resource_list = resource_list_method(request=list_request) or [] + + return [ + self._construct_sdk_resource_from_gapic( + gapic_resource, project=project, location=location, credentials=creds + ) + for gapic_resource in resource_list + if cls_filter(gapic_resource) + ] + + @classmethod + def _list_with_local_order( + cls, + cls_filter: Callable[[proto.Message], bool] = lambda _: True, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[AiPlatformResourceNoun]: + """Client-side sorting when list API doesn't support `order_by`""" + + li = cls._list( + cls_filter=cls_filter, + filter=filter, + order_by=None, # This method will handle the ordering locally + project=project, + location=location, + credentials=credentials, + ) + + desc = "desc" in order_by + order_by = order_by.replace("desc", "") + order_by = order_by.split(",") + + li.sort( + key=lambda x: tuple(getattr(x, field.strip()) for field in order_by), + reverse=desc, + ) + + return li + @classmethod def list( cls, @@ -633,80 +702,17 @@ def list( credentials set in aiplatform.init. Returns: - Sequence[AiPlatformResourceNoun] - A list of SDK resource objects + List[AiPlatformResourceNoun] - A list of SDK resource objects """ - _UNSUPPORTED_LIST_ORDER_BY_TYPES = ( - aiplatform.jobs._Job, - aiplatform.models.Endpoint, - aiplatform.models.Model, - aiplatform.training_jobs._TrainingJob, - ) - self = cls._empty_constructor( - project=project, location=location, credentials=credentials + return cls._list( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, ) - creds = initializer.global_config.credentials - - resource_list_method = getattr(self.api_client, self._list_method) - order_locally = False - - list_request = { - "parent": initializer.global_config.common_location_path(), - "filter": filter, - } - - # If list method does not offer `order_by` field, order locally - if order_by and issubclass(type(self), _UNSUPPORTED_LIST_ORDER_BY_TYPES): - order_locally = True - elif order_by: - list_request["order_by"] = order_by - - resource_list = resource_list_method(request=list_request) or [] - - # Only return objects specific to the calling subclass, - # for example TabularDataset.list() only lists TabularDatasets - if issubclass(type(self), aiplatform.datasets.Dataset): - final_list = [ - self._construct_sdk_resource_from_gapic( - gapic_resource, credentials=creds - ) - for gapic_resource in resource_list - if gapic_resource.metadata_schema_uri - in self._supported_metadata_schema_uris - ] - - elif issubclass(type(self), aiplatform.training_jobs._TrainingJob): - final_list = [ - self._construct_sdk_resource_from_gapic( - gapic_resource, credentials=creds - ) - for gapic_resource in resource_list - if gapic_resource.training_task_definition - in self._supported_training_schemas - ] - - else: - final_list = [ - self._construct_sdk_resource_from_gapic( - gapic_resource, credentials=creds - ) - for gapic_resource in resource_list - ] - - # Client-side sorting when API doesn't support `order_by` - if order_locally: - desc = "desc" in order_by - order_by = order_by.replace("desc", "") - order_by = order_by.split(",") - - final_list.sort( - key=lambda x: tuple(getattr(x, field.strip()) for field in order_by), - reverse=desc, - ) - - return final_list - @optional_sync() def delete(self, sync: bool = True) -> None: """Deletes this AI Platform resource. WARNING: This deletion is permament. diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py index 48ccd46c7f..f98523e6cc 100644 --- a/google/cloud/aiplatform/datasets/dataset.py +++ b/google/cloud/aiplatform/datasets/dataset.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Optional, Sequence, Dict, Tuple, Union +from typing import Optional, Sequence, Dict, Tuple, Union, List from google.api_core import operation from google.auth import credentials as auth_credentials @@ -492,3 +492,57 @@ def export_data(self, output_dir: str) -> Sequence[str]: def update(self): raise NotImplementedError("Update dataset has not been implemented yet") + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[base.AiPlatformResourceNoun]: + """List all instances of this Dataset resource. + + Example Usage: + + aiplatform.TabularDataset.list( + filter='labels.my_key="my_value"', + order_by='display_name' + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[base.AiPlatformResourceNoun] - A list of Dataset resource objects + """ + + dataset_subclass_filter = ( + lambda gapic_obj: gapic_obj.metadata_schema_uri + in cls._supported_metadata_schema_uris + ) + + return cls._list_with_local_order( + cls_filter=dataset_subclass_filter, + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 4587c9281d..39778407f7 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Iterable, Optional, Union, Sequence, Dict +from typing import Iterable, Optional, Union, Sequence, Dict, List import abc import sys @@ -166,6 +166,53 @@ def _block_until_complete(self): if self.state in _JOB_ERROR_STATES: raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error) + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[base.AiPlatformResourceNoun]: + """List all instances of this Job Resource. + + Example Usage: + + aiplatform.BatchPredictionJobs.list( + filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"', + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[AiPlatformResourceNoun] - A list of Job resource objects + """ + + return cls._list_with_local_order( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) + def cancel(self) -> None: """Cancels this Job. Success of cancellation is not guaranteed. Use `Job.state` property to verify if cancellation was successful.""" diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 23df01fc93..58ad2f3059 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -1031,6 +1031,53 @@ def explain( explanations=explain_response.explanations, ) + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List["aiplatform.Endpoint"]: + """List all Endpoint resource instances. + + Example Usage: + + aiplatform.Endpoint.list( + filter='labels.my_label="my_label_value" OR display_name=!"old_endpoint"', + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[aiplatform.Endpoint] - A list of Endpoint resource objects + """ + + return cls._list_with_local_order( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) + def list_models(self) -> Sequence[gca_endpoint.DeployedModel]: """Returns a list of the models deployed to this Endpoint. @@ -1831,3 +1878,50 @@ def batch_predict( encryption_spec_key_name=encryption_spec_key_name, sync=sync, ) + + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List["aiplatform.Model"]: + """List all Model resource instances. + + Example Usage: + + aiplatform.Model.list( + filter='labels.my_label="my_label_value" AND display_name="my_model"', + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[aiplatform.Model] - A list of Model resource objects + """ + + return cls._list_with_local_order( + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 402a0d8b3e..bd914cf1e1 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -690,6 +690,60 @@ def _assert_has_run(self) -> bool: ) return False + @classmethod + def list( + cls, + filter: Optional[str] = None, + order_by: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List["base.AiPlatformResourceNoune"]: + """List all instances of this TrainingJob resource. + + Example Usage: + + aiplatform.CustomTrainingJob.list( + filter='display_name="experiment_a27"', + order_by='create_time desc' + ) + + Args: + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[AiPlatformResourceNoun] - A list of TrainingJob resource objects + """ + + training_job_subclass_filter = ( + lambda gapic_obj: gapic_obj.training_task_definition + in cls._supported_training_schemas + ) + + return cls._list_with_local_order( + cls_filter=training_job_subclass_filter, + filter=filter, + order_by=order_by, + project=project, + location=location, + credentials=credentials, + ) + def cancel(self) -> None: """Starts asynchronous cancellation on the TrainingJob. The server makes a best effort to cancel the job, but success is not guaranteed. diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py index 71f592c9b8..a0679f35de 100644 --- a/tests/unit/aiplatform/test_datasets.py +++ b/tests/unit/aiplatform/test_datasets.py @@ -782,11 +782,7 @@ def test_list_dataset(self, list_datasets_mock): ) list_datasets_mock.assert_called_once_with( - request={ - "parent": _TEST_PARENT, - "filter": _TEST_LIST_FILTER, - "order_by": _TEST_LIST_ORDER_BY, - } + request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} ) # Ensure returned list is smaller since it filtered out non-tabular datasets From 693de0a7648c0bf79deb50c41007110cabe1179c Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Fri, 9 Apr 2021 11:26:49 -0600 Subject: [PATCH 7/9] Resolve build issue from merge conflict --- google/cloud/aiplatform/base.py | 2 +- google/cloud/aiplatform/models.py | 9 ++++--- tests/unit/aiplatform/test_datasets.py | 16 +++++++----- tests/unit/aiplatform/test_models.py | 36 +++++++++++++++++++++++++- 4 files changed, 51 insertions(+), 12 deletions(-) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 01633471bd..de0d3dcd0a 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -22,7 +22,7 @@ import inspect import proto import threading -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union from google.auth import credentials as auth_credentials from google.cloud.aiplatform import initializer diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index debae9314b..8fba59f7e8 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -24,6 +24,7 @@ from google.cloud.aiplatform import explain from google.cloud.aiplatform import initializer from google.cloud.aiplatform import jobs +from google.cloud.aiplatform import models from google.cloud.aiplatform import utils from google.cloud.aiplatform.compat.services import endpoint_service_client @@ -1059,7 +1060,7 @@ def list( project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, - ) -> List["aiplatform.Endpoint"]: + ) -> List["models.Endpoint"]: """List all Endpoint resource instances. Example Usage: @@ -1087,7 +1088,7 @@ def list( credentials set in aiplatform.init. Returns: - List[aiplatform.Endpoint] - A list of Endpoint resource objects + List[models.Endpoint] - A list of Endpoint resource objects """ return cls._list_with_local_order( @@ -1913,7 +1914,7 @@ def list( project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, - ) -> List["aiplatform.Model"]: + ) -> List["models.Model"]: """List all Model resource instances. Example Usage: @@ -1941,7 +1942,7 @@ def list( credentials set in aiplatform.init. Returns: - List[aiplatform.Model] - A list of Model resource objects + List[models.Model] - A list of Model resource objects """ return cls._list_with_local_order( diff --git a/tests/unit/aiplatform/test_datasets.py b/tests/unit/aiplatform/test_datasets.py index 8f86fa5fcd..b9d8e45ed2 100644 --- a/tests/unit/aiplatform/test_datasets.py +++ b/tests/unit/aiplatform/test_datasets.py @@ -110,17 +110,19 @@ _TEST_OUTPUT_DIR = "gs://my-output-bucket" _TEST_DATASET_LIST = [ - GapicDataset( + gca_dataset.Dataset( display_name="a", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR ), - GapicDataset( + gca_dataset.Dataset( display_name="d", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_NONTABULAR ), - GapicDataset( + gca_dataset.Dataset( display_name="b", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR ), - GapicDataset(display_name="e", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT), - GapicDataset( + gca_dataset.Dataset( + display_name="e", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TEXT + ), + gca_dataset.Dataset( display_name="c", metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_TABULAR ), ] @@ -266,7 +268,9 @@ def export_data_mock(): @pytest.fixture def list_datasets_mock(): - with patch.object(DatasetServiceClient, "list_datasets") as list_datasets_mock: + with patch.object( + dataset_service_client.DatasetServiceClient, "list_datasets" + ) as list_datasets_mock: list_datasets_mock.return_value = _TEST_DATASET_LIST yield list_datasets_mock diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 3b4c434f3b..7994bda8e7 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -302,7 +302,9 @@ def delete_model_mock(): @pytest.fixture def list_models_mock(): - with mock.patch.object(ModelServiceClient, "list_models") as list_models_mock: + with mock.patch.object( + model_service_client.ModelServiceClient, "list_models" + ) as list_models_mock: list_models_mock.return_value = _TEST_MODEL_LIST yield list_models_mock @@ -588,7 +590,39 @@ def test_upload_uploads_and_gets_model_with_custom_project( ): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + test_model_resource_name = model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID + ) + + my_model = models.Model.upload( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + project=_TEST_PROJECT_2, + sync=sync, + ) + + if not sync: + my_model.wait() + + container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + ) + + managed_model = gca_model.Model( + display_name=_TEST_MODEL_NAME, + artifact_uri=_TEST_ARTIFACT_URI, + container_spec=container_spec, + ) + + upload_model_with_custom_project_mock.assert_called_once_with( parent=f"projects/{_TEST_PROJECT_2}/locations/{_TEST_LOCATION}", + model=managed_model, ) get_model_with_custom_project_mock.assert_called_once_with( From 885c8d67fc1cb864d937803ed959b370e3f7afd8 Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Sat, 10 Apr 2021 13:34:02 -0600 Subject: [PATCH 8/9] Add doc strings, address review comments --- google/cloud/aiplatform/base.py | 84 +++++++++++++++++++++++++++++-- google/cloud/aiplatform/models.py | 2 +- 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index de0d3dcd0a..ed5a47dc3b 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -585,7 +585,26 @@ def _construct_sdk_resource_from_gapic( location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, ) -> AiPlatformResourceNoun: - """Given a GAPIC object, return the SDK representation.""" + """Given a GAPIC resource object, return the SDK representation. + + Args: + gapic_resource (proto.Message): + A GAPIC representation of an AI Platform resource, usually + retrieved by a get_* or in a list_* API call. + project (str): + Optional. Project to construct SDK object from. If not set, + project set in aiplatform.init will be used. + location (str): + Optional. Location to construct SDK object from. If not set, + location set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to construct SDK object. + Overrides credentials set in aiplatform.init. + + Returns: + AiPlatformResourceNoun: + An initialized SDK object that represents GAPIC type. + """ sdk_resource = self._empty_constructor( project=project, location=location, credentials=credentials ) @@ -604,7 +623,35 @@ def _list( location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, ) -> List[AiPlatformResourceNoun]: + """Private method to list all instances of this AI Platform Resource, + takes a `cls_filter` arg to filter to a particular SDK resource subclass. + + Args: + cls_filter (Callable[[proto.Message], bool]): + A function that takes one argument, a GAPIC resource, and returns + a bool. If the function returns False, that resource will be + excluded from the returned list. Example usage: + cls_filter = lambda obj: obj.metadata in cls.valid_metadatas + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + Returns: + List[AiPlatformResourceNoun] - A list of SDK resource objects + """ self = cls._empty_constructor( project=project, location=location, credentials=credentials ) @@ -615,7 +662,9 @@ def _list( resource_list_method = getattr(self.api_client, self._list_method) list_request = { - "parent": initializer.global_config.common_location_path(), + "parent": initializer.global_config.common_location_path( + project=project, location=location + ), "filter": filter, } @@ -642,7 +691,36 @@ def _list_with_local_order( location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, ) -> List[AiPlatformResourceNoun]: - """Client-side sorting when list API doesn't support `order_by`""" + """Private method to list all instances of this AI Platform Resource, + takes a `cls_filter` arg to filter to a particular SDK resource subclass. + Provides client-side sorting when a list API doesn't support `order_by`. + + Args: + cls_filter (Callable[[proto.Message], bool]): + A function that takes one argument, a GAPIC resource, and returns + a bool. If the function returns False, that resource will be + excluded from the returned list. Example usage: + cls_filter = lambda obj: obj.metadata in cls.valid_metadatas + filter (str): + Optional. An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported. + order_by (str): + Optional. A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for descending. + Supported fields: `display_name`, `create_time`, `update_time` + project (str): + Optional. Project to retrieve list from. If not set, project + set in aiplatform.init will be used. + location (str): + Optional. Location to retrieve list from. If not set, location + set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to retrieve list. Overrides + credentials set in aiplatform.init. + + Returns: + List[AiPlatformResourceNoun] - A list of SDK resource objects + """ li = cls._list( cls_filter=cls_filter, diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 8fba59f7e8..9fb89e4e89 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -1945,7 +1945,7 @@ def list( List[models.Model] - A list of Model resource objects """ - return cls._list_with_local_order( + return cls._list( filter=filter, order_by=order_by, project=project, From 9d5471177f2c230ba3b7dc102beabbb800d22b42 Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Sat, 10 Apr 2021 23:07:59 -0600 Subject: [PATCH 9/9] Move Model.list tests to Endpoint.list to test local sorting and filtering --- tests/unit/aiplatform/test_endpoints.py | 68 +++++++++++++++++++++++++ tests/unit/aiplatform/test_models.py | 66 ------------------------ 2 files changed, 68 insertions(+), 66 deletions(-) diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index 7b18e1e497..ea74c89e5e 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -19,6 +19,7 @@ from unittest import mock from importlib import reload +from datetime import datetime, timedelta from google.api_core import operation as ga_operation from google.auth import credentials as auth_credentials @@ -138,6 +139,23 @@ ) +_TEST_ENDPOINT_LIST = [ + gca_endpoint.Endpoint( + display_name="aac", create_time=datetime.now() - timedelta(minutes=15) + ), + gca_endpoint.Endpoint( + display_name="aab", create_time=datetime.now() - timedelta(minutes=5) + ), + gca_endpoint.Endpoint( + display_name="aaa", create_time=datetime.now() - timedelta(minutes=10) + ), +] + +_TEST_LIST_FILTER = 'display_name="abc"' +_TEST_LIST_ORDER_BY_CREATE_TIME = "create_time desc" +_TEST_LIST_ORDER_BY_DISPLAY_NAME = "display_name" + + @pytest.fixture def get_endpoint_mock(): with mock.patch.object( @@ -264,6 +282,15 @@ def sdk_undeploy_all_mock(): yield sdk_undeploy_all_mock +@pytest.fixture +def list_endpoints_mock(): + with mock.patch.object( + endpoint_service_client.EndpointServiceClient, "list_endpoints" + ) as list_endpoints_mock: + list_endpoints_mock.return_value = _TEST_ENDPOINT_LIST + yield list_endpoints_mock + + @pytest.fixture def create_client_mock(): with mock.patch.object( @@ -307,6 +334,7 @@ class TestEndpoint: def setup_method(self): reload(initializer) reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) def teardown_method(self): initializer.global_pool.shutdown(wait=True) @@ -974,6 +1002,46 @@ def test_undeploy_all(self, sdk_private_undeploy_mock, sync): any_order=True, ) + def test_list_endpoint_order_by_time(self, list_endpoints_mock): + """Test call to Endpoint.list() and ensure list is returned in descending order of create_time""" + + ep_list = aiplatform.Endpoint.list( + filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY_CREATE_TIME + ) + + # `order_by` is not passed to API since it is not an accepted field + list_endpoints_mock.assert_called_once_with( + request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} + ) + + assert len(ep_list) == len(_TEST_ENDPOINT_LIST) + + for ep in ep_list: + assert type(ep) == aiplatform.Endpoint + + assert ep_list[0].create_time > ep_list[1].create_time > ep_list[2].create_time + + def test_list_endpoint_order_by_display_name(self, list_endpoints_mock): + """Test call to Endpoint.list() and ensure list is returned in order of display_name""" + + ep_list = aiplatform.Endpoint.list( + filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY_DISPLAY_NAME + ) + + # `order_by` is not passed to API since it is not an accepted field + list_endpoints_mock.assert_called_once_with( + request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} + ) + + assert len(ep_list) == len(_TEST_ENDPOINT_LIST) + + for ep in ep_list: + assert type(ep) == aiplatform.Endpoint + + assert ( + ep_list[0].display_name < ep_list[1].display_name < ep_list[2].display_name + ) + @pytest.mark.usefixtures("get_endpoint_with_models_mock") @pytest.mark.parametrize("sync", [True, False]) def test_delete_endpoint_without_force( diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 7994bda8e7..47b000d189 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -19,7 +19,6 @@ from concurrent import futures import pytest from unittest import mock -from datetime import datetime, timedelta from google.api_core import operation as ga_operation from google.auth import credentials as auth_credentials @@ -158,22 +157,6 @@ _TEST_OUTPUT_DIR = "gs://my-output-bucket" -_TEST_MODEL_LIST = [ - gca_model.Model( - display_name="aac", create_time=datetime.now() - timedelta(minutes=15) - ), - gca_model.Model( - display_name="aab", create_time=datetime.now() - timedelta(minutes=5) - ), - gca_model.Model( - display_name="aaa", create_time=datetime.now() - timedelta(minutes=10) - ), -] - -_TEST_LIST_FILTER = 'display_name="abc"' -_TEST_LIST_ORDER_BY_CREATE_TIME = "create_time desc" -_TEST_LIST_ORDER_BY_DISPLAY_NAME = "display_name" - @pytest.fixture def get_endpoint_mock(): @@ -300,15 +283,6 @@ def delete_model_mock(): yield delete_model_mock -@pytest.fixture -def list_models_mock(): - with mock.patch.object( - model_service_client.ModelServiceClient, "list_models" - ) as list_models_mock: - list_models_mock.return_value = _TEST_MODEL_LIST - yield list_models_mock - - @pytest.fixture def deploy_model_mock(): with mock.patch.object( @@ -1123,46 +1097,6 @@ def test_delete_model(self, delete_model_mock, sync): delete_model_mock.assert_called_once_with(name=test_model.resource_name) - def test_list_model_order_by_time(self, list_models_mock): - """Test call to Model.list() and ensure list is returned in descending order of create_time""" - - ds_list = aiplatform.Model.list( - filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY_CREATE_TIME - ) - - # `order_by` is not passed to API since it is not an accepted field - list_models_mock.assert_called_once_with( - request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} - ) - - assert len(ds_list) == len(_TEST_MODEL_LIST) - - for ds in ds_list: - assert type(ds) == aiplatform.Model - - assert ds_list[0].create_time > ds_list[1].create_time > ds_list[2].create_time - - def test_list_model_order_by_display_name(self, list_models_mock): - """Test call to Model.list() and ensure list is returned in order of display_name""" - - ds_list = aiplatform.Model.list( - filter=_TEST_LIST_FILTER, order_by=_TEST_LIST_ORDER_BY_DISPLAY_NAME - ) - - # `order_by` is not passed to API since it is not an accepted field - list_models_mock.assert_called_once_with( - request={"parent": _TEST_PARENT, "filter": _TEST_LIST_FILTER} - ) - - assert len(ds_list) == len(_TEST_MODEL_LIST) - - for ds in ds_list: - assert type(ds) == aiplatform.Model - - assert ( - ds_list[0].display_name < ds_list[1].display_name < ds_list[2].display_name - ) - @pytest.mark.usefixtures("get_model_mock") def test_print_model(self): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)