diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index 18341bde46..9f0ad719f9 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -92,11 +92,19 @@ def init( if metadata.metadata_service.experiment_name: logging.info("project/location updated, reset Metadata config.") metadata.metadata_service.reset() + if project: self._project = project if location: utils.validate_region(location) self._location = location + if staging_bucket: + self._staging_bucket = staging_bucket + if credentials: + self._credentials = credentials + if encryption_spec_key_name: + self._encryption_spec_key_name = encryption_spec_key_name + if experiment: metadata.metadata_service.set_experiment( experiment=experiment, description=experiment_description @@ -105,12 +113,6 @@ def init( raise ValueError( "Experiment name needs to be set in `init` in order to add experiment descriptions." ) - if staging_bucket: - self._staging_bucket = staging_bucket - if credentials: - self._credentials = credentials - if encryption_spec_key_name: - self._encryption_spec_key_name = encryption_spec_key_name def get_encryption_spec( self, diff --git a/google/cloud/aiplatform/metadata/metadata_store.py b/google/cloud/aiplatform/metadata/metadata_store.py index 494d31aca4..3327f47d1f 100644 --- a/google/cloud/aiplatform/metadata/metadata_store.py +++ b/google/cloud/aiplatform/metadata/metadata_store.py @@ -205,7 +205,7 @@ def _get( project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, - ) -> "Optional[_MetadataStore]": + ) -> Optional["_MetadataStore"]: """Returns a MetadataStore resource. Args: diff --git a/tests/unit/aiplatform/test_metadata.py b/tests/unit/aiplatform/test_metadata.py index 9a930dd3f5..36de8938b4 100644 --- a/tests/unit/aiplatform/test_metadata.py +++ b/tests/unit/aiplatform/test_metadata.py @@ -16,10 +16,13 @@ # from importlib import reload +from unittest import mock from unittest.mock import patch, call import pytest from google.api_core import exceptions +from google.api_core import operation +from google.auth import credentials from google.cloud import aiplatform from google.cloud.aiplatform import initializer @@ -106,6 +109,32 @@ def get_metadata_store_mock(): yield get_metadata_store_mock +@pytest.fixture +def get_metadata_store_mock_raise_not_found_exception(): + with patch.object( + MetadataServiceClient, "get_metadata_store" + ) as get_metadata_store_mock: + get_metadata_store_mock.side_effect = [ + exceptions.NotFound("Test store not found."), + GapicMetadataStore(name=_TEST_METADATASTORE,), + ] + + yield get_metadata_store_mock + + +@pytest.fixture +def create_metadata_store_mock(): + with patch.object( + MetadataServiceClient, "create_metadata_store" + ) as create_metadata_store_mock: + create_metadata_store_lro_mock = mock.Mock(operation.Operation) + create_metadata_store_lro_mock.result.return_value = GapicMetadataStore( + name=_TEST_METADATASTORE, + ) + create_metadata_store_mock.return_value = create_metadata_store_lro_mock + yield create_metadata_store_mock + + @pytest.fixture def get_context_mock(): with patch.object(MetadataServiceClient, "get_context") as get_context_mock: @@ -364,6 +393,54 @@ def test_init_experiment_with_existing_metadataStore_and_context( get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE) get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) + def test_init_experiment_with_credentials( + self, get_metadata_store_mock, get_context_mock + ): + creds = credentials.AnonymousCredentials() + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + experiment=_TEST_EXPERIMENT, + credentials=creds, + ) + + assert ( + metadata.metadata_service._experiment.api_client._transport._credentials + == creds + ) + + get_metadata_store_mock.assert_called_once_with(name=_TEST_METADATASTORE) + get_context_mock.assert_called_once_with(name=_TEST_CONTEXT_NAME) + + def test_init_and_get_metadata_store_with_credentials( + self, get_metadata_store_mock + ): + creds = credentials.AnonymousCredentials() + + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=creds + ) + + store = metadata._MetadataStore.get_or_create() + + assert store.api_client._transport._credentials == creds + + @pytest.mark.usefixtures( + "get_metadata_store_mock_raise_not_found_exception", + "create_metadata_store_mock", + ) + def test_init_and_get_then_create_metadata_store_with_credentials(self): + creds = credentials.AnonymousCredentials() + + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=creds + ) + + store = metadata._MetadataStore.get_or_create() + + assert store.api_client._transport._credentials == creds + def test_init_experiment_with_existing_description( self, get_metadata_store_mock, get_context_mock ):