Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add list() method to all resource nouns #294

Merged
merged 12 commits into from
Apr 11, 2021
158 changes: 155 additions & 3 deletions google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

import abc
from concurrent import futures
import datetime
import functools
import inspect
import threading
from typing import Any, Callable, Dict, Optional, Sequence, Type, Union

import proto
import threading
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import initializer
Expand Down Expand Up @@ -249,6 +249,12 @@ def _getter_method(cls) -> str:
"""Name of getter method of client class for retrieving the resource."""
pass

@property
@abc.abstractmethod
def _list_method(cls) -> str:
"""Name of list method of client class for listing resources."""
pass

@property
@abc.abstractmethod
def _delete_method(cls) -> str:
Expand Down Expand Up @@ -343,6 +349,17 @@ def display_name(self) -> str:
"""Display name of this resource."""
return self._gca_resource.display_name

@property
def create_time(self) -> datetime.datetime:
"""Time this resource was created."""
return self._gca_resource.create_time

@property
def update_time(self) -> datetime.datetime:
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved
"""Time this resource was last updated."""
self._sync_gca_resource()
return self._gca_resource.update_time

def __repr__(self) -> str:
return f"{object.__repr__(self)} \nresource name: {self.resource_name}"

Expand Down Expand Up @@ -561,6 +578,141 @@ def _sync_object_with_future_result(
if value:
setattr(self, attribute, value)

def _construct_sdk_resource_from_gapic(
self,
gapic_resource: proto.Message,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> AiPlatformResourceNoun:
"""Given a GAPIC object, return the SDK representation."""
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved
sdk_resource = self._empty_constructor(
project=project, location=location, credentials=credentials
)
sdk_resource._gca_resource = gapic_resource
return sdk_resource

# TODO(b/144545165): Improve documentation for list filtering once available
# TODO(b/184910159): Expose `page_size` field in list method
@classmethod
def _list(
cls,
cls_filter: Callable[[proto.Message], bool] = lambda _: True,
filter: Optional[str] = None,
order_by: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[AiPlatformResourceNoun]:
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved

self = cls._empty_constructor(
project=project, location=location, credentials=credentials
)

# Fetch credentials once and re-use for all `_empty_constructor()` calls
creds = initializer.global_config.credentials

resource_list_method = getattr(self.api_client, self._list_method)

list_request = {
"parent": initializer.global_config.common_location_path(),
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved
"filter": filter,
}

if order_by:
list_request["order_by"] = order_by

resource_list = resource_list_method(request=list_request) or []

return [
self._construct_sdk_resource_from_gapic(
gapic_resource, project=project, location=location, credentials=creds
)
for gapic_resource in resource_list
if cls_filter(gapic_resource)
]

@classmethod
def _list_with_local_order(
cls,
cls_filter: Callable[[proto.Message], bool] = lambda _: True,
filter: Optional[str] = None,
order_by: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[AiPlatformResourceNoun]:
"""Client-side sorting when list API doesn't support `order_by`"""
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved

li = cls._list(
cls_filter=cls_filter,
filter=filter,
order_by=None, # This method will handle the ordering locally
project=project,
location=location,
credentials=credentials,
)

desc = "desc" in order_by
order_by = order_by.replace("desc", "")
order_by = order_by.split(",")

li.sort(
key=lambda x: tuple(getattr(x, field.strip()) for field in order_by),
reverse=desc,
)

return li

@classmethod
def list(
vinnysenthil marked this conversation as resolved.
Show resolved Hide resolved
cls,
filter: Optional[str] = None,
order_by: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[AiPlatformResourceNoun]:
"""List all instances of this AI Platform Resource.

Example Usage:

aiplatform.BatchPredictionJobs.list(
filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"',
)

aiplatform.Model.list(order_by="create_time desc, display_name")

Args:
filter (str):
Optional. An expression for filtering the results of the request.
For field names both snake_case and camelCase are supported.
order_by (str):
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
sasha-gitg marked this conversation as resolved.
Show resolved Hide resolved
Supported fields: `display_name`, `create_time`, `update_time`
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve list from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve list. Overrides
credentials set in aiplatform.init.

Returns:
List[AiPlatformResourceNoun] - A list of SDK resource objects
"""

return cls._list(
filter=filter,
order_by=order_by,
project=project,
location=location,
credentials=credentials,
)

@optional_sync()
def delete(self, sync: bool = True) -> None:
"""Deletes this AI Platform resource. WARNING: This deletion is permament.
Expand Down
59 changes: 57 additions & 2 deletions google/cloud/aiplatform/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from typing import Optional, Sequence, Dict, Tuple, Union
from typing import Optional, Sequence, Dict, Tuple, Union, List

from google.api_core import operation
from google.auth import credentials as auth_credentials
Expand All @@ -40,9 +40,10 @@ class Dataset(base.AiPlatformResourceNounWithFutureManager):
_is_client_prediction_client = False
_resource_noun = "datasets"
_getter_method = "get_dataset"
_list_method = "list_datasets"
_delete_method = "delete_dataset"

_supported_metadata_schema_uris: Optional[Tuple[str]] = None
_supported_metadata_schema_uris: Tuple[str] = ()

def __init__(
self,
Expand Down Expand Up @@ -491,3 +492,57 @@ def export_data(self, output_dir: str) -> Sequence[str]:

def update(self):
raise NotImplementedError("Update dataset has not been implemented yet")

@classmethod
def list(
cls,
filter: Optional[str] = None,
order_by: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[base.AiPlatformResourceNoun]:
"""List all instances of this Dataset resource.

Example Usage:

aiplatform.TabularDataset.list(
filter='labels.my_key="my_value"',
order_by='display_name'
)

Args:
filter (str):
Optional. An expression for filtering the results of the request.
For field names both snake_case and camelCase are supported.
order_by (str):
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve list from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve list. Overrides
credentials set in aiplatform.init.

Returns:
List[base.AiPlatformResourceNoun] - A list of Dataset resource objects
"""

dataset_subclass_filter = (
lambda gapic_obj: gapic_obj.metadata_schema_uri
in cls._supported_metadata_schema_uris
)

return cls._list_with_local_order(
cls_filter=dataset_subclass_filter,
filter=filter,
order_by=order_by,
project=project,
location=location,
credentials=credentials,
)
53 changes: 52 additions & 1 deletion google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from typing import Iterable, Optional, Union, Sequence, Dict
from typing import Iterable, Optional, Union, Sequence, Dict, List

import abc
import sys
Expand Down Expand Up @@ -168,6 +168,53 @@ def _block_until_complete(self):
if self.state in _JOB_ERROR_STATES:
raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error)

@classmethod
def list(
cls,
filter: Optional[str] = None,
order_by: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[base.AiPlatformResourceNoun]:
"""List all instances of this Job Resource.

Example Usage:

aiplatform.BatchPredictionJobs.list(
filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"',
)

Args:
filter (str):
Optional. An expression for filtering the results of the request.
For field names both snake_case and camelCase are supported.
order_by (str):
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve list from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve list. Overrides
credentials set in aiplatform.init.

Returns:
List[AiPlatformResourceNoun] - A list of Job resource objects
"""

return cls._list_with_local_order(
filter=filter,
order_by=order_by,
project=project,
location=location,
credentials=credentials,
)

def cancel(self) -> None:
"""Cancels this Job. Success of cancellation is not guaranteed. Use `Job.state`
property to verify if cancellation was successful."""
Expand All @@ -178,6 +225,7 @@ class BatchPredictionJob(_Job):

_resource_noun = "batchPredictionJobs"
_getter_method = "get_batch_prediction_job"
_list_method = "list_batch_prediction_jobs"
_cancel_method = "cancel_batch_prediction_job"
_delete_method = "delete_batch_prediction_job"
_job_type = "batch-predictions"
Expand Down Expand Up @@ -699,6 +747,7 @@ def iter_outputs(
class CustomJob(_Job):
_resource_noun = "customJobs"
_getter_method = "get_custom_job"
_list_method = "list_custom_job"
_cancel_method = "cancel_custom_job"
_delete_method = "delete_custom_job"
_job_type = "training"
Expand All @@ -708,6 +757,7 @@ class CustomJob(_Job):
class DataLabelingJob(_Job):
_resource_noun = "dataLabelingJobs"
_getter_method = "get_data_labeling_job"
_list_method = "list_data_labeling_jobs"
_cancel_method = "cancel_data_labeling_job"
_delete_method = "delete_data_labeling_job"
_job_type = "labeling-tasks"
Expand All @@ -717,6 +767,7 @@ class DataLabelingJob(_Job):
class HyperparameterTuningJob(_Job):
_resource_noun = "hyperparameterTuningJobs"
_getter_method = "get_hyperparameter_tuning_job"
_list_method = "list_hyperparameter_tuning_jobs"
_cancel_method = "cancel_hyperparameter_tuning_job"
_delete_method = "delete_hyperparameter_tuning_job"
pass
Loading