Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Enable MetadataStore to use credentials when aiplatfrom.init passed experiment and credentials. #460

Merged
merged 1 commit into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,19 @@ def init(
if metadata.metadata_service.experiment_name:
logging.info("project/location updated, reset Metadata config.")
metadata.metadata_service.reset()

if project:
self._project = project
if location:
utils.validate_region(location)
self._location = location
if staging_bucket:
self._staging_bucket = staging_bucket
if credentials:
self._credentials = credentials
if encryption_spec_key_name:
self._encryption_spec_key_name = encryption_spec_key_name

if experiment:
metadata.metadata_service.set_experiment(
experiment=experiment, description=experiment_description
Expand All @@ -105,12 +113,6 @@ def init(
raise ValueError(
"Experiment name needs to be set in `init` in order to add experiment descriptions."
)
if staging_bucket:
self._staging_bucket = staging_bucket
if credentials:
self._credentials = credentials
if encryption_spec_key_name:
self._encryption_spec_key_name = encryption_spec_key_name

def get_encryption_spec(
self,
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/metadata/metadata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def _get(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "Optional[_MetadataStore]":
) -> Optional["_MetadataStore"]:
"""Returns a MetadataStore resource.

Args:
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/aiplatform/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
#

from importlib import reload
from unittest import mock
from unittest.mock import patch, call

import pytest
from google.api_core import exceptions
from google.api_core import operation
from google.auth import credentials

from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
Expand Down Expand Up @@ -106,6 +109,32 @@ def get_metadata_store_mock():
yield get_metadata_store_mock


@pytest.fixture
def get_metadata_store_mock_raise_not_found_exception():
with patch.object(
MetadataServiceClient, "get_metadata_store"
) as get_metadata_store_mock:
get_metadata_store_mock.side_effect = [
exceptions.NotFound("Test store not found."),
GapicMetadataStore(name=_TEST_METADATASTORE,),
]

yield get_metadata_store_mock


@pytest.fixture
def create_metadata_store_mock():
with patch.object(
MetadataServiceClient, "create_metadata_store"
) as create_metadata_store_mock:
create_metadata_store_lro_mock = mock.Mock(operation.Operation)
create_metadata_store_lro_mock.result.return_value = GapicMetadataStore(
name=_TEST_METADATASTORE,
)
create_metadata_store_mock.return_value = create_metadata_store_lro_mock
yield create_metadata_store_mock


@pytest.fixture
def get_context_mock():
with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
Expand Down Expand Up @@ -364,6 +393,54 @@ def test_init_experiment_with_existing_metadataStore_and_context(
get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE)
get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME)

def test_init_experiment_with_credentials(
self, get_metadata_store_mock, get_context_mock
):
creds = credentials.AnonymousCredentials()

aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
experiment=_TEST_EXPERIMENT,
credentials=creds,
)

assert (
metadata.metadata_service._experiment.api_client._transport._credentials
== creds
)

get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE)
get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME)

def test_init_and_get_metadata_store_with_credentials(
self, get_metadata_store_mock
):
creds = credentials.AnonymousCredentials()

aiplatform.init(
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=creds
)

store = metadata._MetadataStore.get_or_create()

assert store.api_client._transport._credentials == creds

@pytest.mark.usefixtures(
"get_metadata_store_mock_raise_not_found_exception",
"create_metadata_store_mock",
)
def test_init_and_get_then_create_metadata_store_with_credentials(self):
creds = credentials.AnonymousCredentials()

aiplatform.init(
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=creds
)

store = metadata._MetadataStore.get_or_create()

assert store.api_client._transport._credentials == creds

def test_init_experiment_with_existing_description(
self, get_metadata_store_mock, get_context_mock
):
Expand Down