Skip to content

Commit

Permalink
feat: parse project location when passed full resource name to get apis
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed Apr 8, 2021
1 parent 10b89e2 commit 98612c2
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 9 deletions.
62 changes: 59 additions & 3 deletions google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import functools
import inspect
import threading
from typing import Any, Callable, Dict, Optional, Sequence, Type, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union

import proto

Expand Down Expand Up @@ -266,6 +266,7 @@ def __init__(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
resource_name: Optional[str] = None,
):
"""Initializes class with project, location, and api_client.
Expand All @@ -274,8 +275,14 @@ def __init__(
location(str): The location of the resource noun.
credentials(google.auth.crendentials.Crendentials): Optional custom
credentials to use when accessing interacting with resource noun.
resource_name(str): A fully-qualified resource name or ID.
"""

if resource_name:
project, location = self._get_and_validate_project_location(
resource_name=resource_name, project=project, location=location
)

self.project = project or initializer.global_config.project
self.location = location or initializer.global_config.location
self.credentials = credentials or initializer.global_config.credentials
Expand Down Expand Up @@ -306,6 +313,41 @@ def _instantiate_client(
prediction_client=cls._is_client_prediction_client,
)

def _get_and_validate_project_location(
self,
resource_name: str,
project: Optional[str] = None,
location: Optional[str] = None,
) -> Tuple:

"""Validate the project and location for the resource.
Args:
resource_name(str): Required. A fully-qualified resource name or ID.
project(str): Project of the resource noun.
location(str): The location of the resource noun.
Raises:
RuntimeError if location is different from resource location
"""

if not project and not location:
return project, location

fields = utils.extract_fields_from_resource_name(
resource_name, self._resource_noun
)
if not fields:
return project, location

if location and fields.location != location:
raise RuntimeError(
f"location {location} is provided, but different from "
f"the resource location {fields.location}"
)

return fields.project, fields.location

def _get_gca_resource(self, resource_name: str) -> proto.Message:
"""Returns GAPIC service representation of client class resource."""
"""
Expand Down Expand Up @@ -493,6 +535,7 @@ def __init__(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
resource_name: Optional[str] = None,
):
"""Initializes class with project, location, and api_client.
Expand All @@ -502,9 +545,14 @@ def __init__(
credentials(google.auth.crendentials.Crendentials):
Optional. custom credentials to use when accessing interacting with
resource noun.
resource_name(str): A fully-qualified resource name or ID.
"""
AiPlatformResourceNoun.__init__(
self, project=project, location=location, credentials=credentials
self,
project=project,
location=location,
credentials=credentials,
resource_name=resource_name,
)
FutureManager.__init__(self)

Expand All @@ -514,6 +562,7 @@ def _empty_constructor(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
resource_name: Optional[str] = None,
) -> "AiPlatformResourceNounWithFutureManager":
"""Initializes with all attributes set to None.
Expand All @@ -526,11 +575,18 @@ def _empty_constructor(
credentials(google.auth.crendentials.Crendentials):
Optional. custom credentials to use when accessing interacting with
resource noun.
resource_name(str): A fully-qualified resource name or ID.
Returns:
An instance of this class with attributes set to None.
"""
self = cls.__new__(cls)
AiPlatformResourceNoun.__init__(self, project, location, credentials)
AiPlatformResourceNoun.__init__(
self,
project=project,
location=location,
credentials=credentials,
resource_name=resource_name,
)
FutureManager.__init__(self)
self._gca_resource = None
return self
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/aiplatform/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ def __init__(
"""

super().__init__(
project=project, location=location, credentials=credentials,
project=project,
location=location,
credentials=credentials,
resource_name=dataset_name,
)
self._gca_resource = self._get_gca_resource(resource_name=dataset_name)
self._validate_metadata_schema_uri()
Expand Down
7 changes: 6 additions & 1 deletion google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,12 @@ def __init__(
Custom credentials to use. If not set, credentials set in
aiplatform.init will be used.
"""
super().__init__(project=project, location=location, credentials=credentials)
super().__init__(
project=project,
location=location,
credentials=credentials,
resource_name=job_name,
)
self._gca_resource = self._get_gca_resource(resource_name=job_name)

@property
Expand Down
14 changes: 12 additions & 2 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,12 @@ def __init__(
credentials set in aiplatform.init.
"""

super().__init__(project=project, location=location, credentials=credentials)
super().__init__(
project=project,
location=location,
credentials=credentials,
resource_name=endpoint_name,
)
self._gca_resource = self._get_gca_resource(resource_name=endpoint_name)
self._prediction_client = self._instantiate_prediction_client(
location=location or initializer.global_config.location,
Expand Down Expand Up @@ -1144,7 +1149,12 @@ def __init__(
credentials set in aiplatform.init will be used.
"""

super().__init__(project=project, location=location, credentials=credentials)
super().__init__(
project=project,
location=location,
credentials=credentials,
resource_name=model_name,
)
self._gca_resource = self._get_gca_resource(resource_name=model_name)

# TODO(b/170979552) Add support for predict schemata
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ def get(
# These parameters won't be used as user can not run the job again.
# If they try, an exception will be raised.
self = cls._empty_constructor(
project=project, location=location, credentials=credentials
project=project,
location=location,
credentials=credentials,
resource_name=resource_name,
)

self._gca_resource = self._get_gca_resource(resource_name=resource_name)
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/aiplatform/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_ALT_PROJECT = "test-project_alt"

_TEST_ALT_LOCATION = "europe-west4"
_TEST_INVALID_LOCATION = "us-central2"
Expand Down Expand Up @@ -259,6 +260,38 @@ def test_init_dataset(self, get_dataset_mock):
datasets.Dataset(dataset_name=_TEST_NAME)
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)

def test_init_dataset_with_id_only_with_project_and_location(
self, get_dataset_mock
):
aiplatform.init(project=_TEST_PROJECT)
datasets.Dataset(
dataset_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION
)
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)

def test_init_dataset_with_project_and_location(self, get_dataset_mock):
aiplatform.init(project=_TEST_PROJECT)
datasets.Dataset(
dataset_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION
)
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)

def test_init_dataset_with_alt_project_and_location(self, get_dataset_mock):
aiplatform.init(project=_TEST_PROJECT)
datasets.Dataset(
dataset_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION
)
get_dataset_mock.assert_called_once_with(name=_TEST_NAME)

def test_init_dataset_with_project_and_alt_location(self):
aiplatform.init(project=_TEST_PROJECT)
with pytest.raises(RuntimeError):
datasets.Dataset(
dataset_name=_TEST_NAME,
project=_TEST_PROJECT,
location=_TEST_ALT_LOCATION,
)

def test_init_dataset_with_id_only(self, get_dataset_mock):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
datasets.Dataset(dataset_name=_TEST_ID)
Expand Down
39 changes: 38 additions & 1 deletion tests/unit/aiplatform/test_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
_TEST_GCS_PATH_WITH_TRAILING_SLASH = f"{_TEST_GCS_PATH}/"
_TEST_LOCAL_SCRIPT_FILE_NAME = "____test____script.py"
_TEST_LOCAL_SCRIPT_FILE_PATH = f"path/to/{_TEST_LOCAL_SCRIPT_FILE_NAME}"
_TEST_PROJECT = "test-project"
_TEST_PYTHON_SOURCE = """
print('hello world')
"""
Expand Down Expand Up @@ -107,6 +106,8 @@
_TEST_NAME = (
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/trainingPipelines/{_TEST_ID}"
)
_TEST_ALT_PROJECT = "test-project-alt"
_TEST_ALT_LOCATION = "europe-west4"

_TEST_MODEL_INSTANCE_SCHEMA_URI = "instance_schema_uri.yaml"
_TEST_MODEL_PARAMETERS_SCHEMA_URI = "parameters_schema_uri.yaml"
Expand Down Expand Up @@ -1381,6 +1382,42 @@ def test_get_training_job_with_id_only(self, get_training_job_custom_mock):
training_jobs.CustomTrainingJob.get(resource_name=_TEST_ID)
get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME)

def test_get_training_job_with_id_only_with_project_and_location(
self, get_training_job_custom_mock
):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
training_jobs.CustomTrainingJob.get(
resource_name=_TEST_ID, project=_TEST_PROJECT, location=_TEST_LOCATION
)
get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME)

def test_get_training_job_with_project_and_location(
self, get_training_job_custom_mock
):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
training_jobs.CustomTrainingJob.get(
resource_name=_TEST_NAME, project=_TEST_PROJECT, location=_TEST_LOCATION
)
get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME)

def test_get_training_job_with_alt_project_and_location(
self, get_training_job_custom_mock
):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
training_jobs.CustomTrainingJob.get(
resource_name=_TEST_NAME, project=_TEST_ALT_PROJECT, location=_TEST_LOCATION
)
get_training_job_custom_mock.assert_called_once_with(name=_TEST_NAME)

def test_get_training_job_with_project_and_alt_location(self):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
with pytest.raises(RuntimeError):
training_jobs.CustomTrainingJob.get(
resource_name=_TEST_NAME,
project=_TEST_PROJECT,
location=_TEST_ALT_LOCATION,
)

@pytest.mark.parametrize("sync", [True, False])
def test_run_call_pipeline_service_create_with_nontabular_dataset(
self,
Expand Down

0 comments on commit 98612c2

Please sign in to comment.