diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 19b5f7d468..868431fbdc 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -46,6 +46,7 @@ from google.cloud.aiplatform import initializer from google.cloud.aiplatform import utils from google.cloud.aiplatform.compat.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform.constants import base as base_constants from google.protobuf import json_format # This is the default retry callback to be used with get methods. @@ -499,7 +500,19 @@ def __init__( self.location = location or initializer.global_config.location self.credentials = credentials or initializer.global_config.credentials - self.api_client = self._instantiate_client(self.location, self.credentials) + appended_user_agent = None + if base_constants.USER_AGENT_SDK_COMMAND: + appended_user_agent = [ + f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}" + ] + # Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls. + base_constants.USER_AGENT_SDK_COMMAND = "" + + self.api_client = self._instantiate_client( + location=self.location, + credentials=self.credentials, + appended_user_agent=appended_user_agent, + ) @classmethod def _instantiate_client( diff --git a/google/cloud/aiplatform/constants/base.py b/google/cloud/aiplatform/constants/base.py index 373f66c003..8c1bc1b613 100644 --- a/google/cloud/aiplatform/constants/base.py +++ b/google/cloud/aiplatform/constants/base.py @@ -88,3 +88,7 @@ # Used in constructing the requests user_agent header for metrics reporting. USER_AGENT_PRODUCT = "model-builder" +# This field is used to pass the name of the specific SDK method +# that is being used for usage metrics tracking purposes. +# For more details on go/oneplatform-api-analytics +USER_AGENT_SDK_COMMAND = "" diff --git a/google/cloud/aiplatform/metadata/artifact.py b/google/cloud/aiplatform/metadata/artifact.py index 5f4f94e18b..c995ecf85b 100644 --- a/google/cloud/aiplatform/metadata/artifact.py +++ b/google/cloud/aiplatform/metadata/artifact.py @@ -28,6 +28,7 @@ from google.cloud.aiplatform.compat.types import ( metadata_service as gca_metadata_service, ) +from google.cloud.aiplatform.constants import base as base_constants from google.cloud.aiplatform.metadata import metadata_store from google.cloud.aiplatform.metadata import resource from google.cloud.aiplatform.metadata import utils as metadata_utils @@ -114,6 +115,7 @@ def _create_resource( artifact_id=resource_id, ) + # TODO() refactor code to move _create to _Resource class. @classmethod def _create( cls, @@ -175,7 +177,19 @@ def _create( Instantiated representation of the managed Metadata resource. """ - api_client = cls._instantiate_client(location=location, credentials=credentials) + appended_user_agent = [] + if base_constants.USER_AGENT_SDK_COMMAND: + appended_user_agent = [ + f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}" + ] + # Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls. + base_constants.USER_AGENT_SDK_COMMAND = "" + + api_client = cls._instantiate_client( + location=location, + credentials=credentials, + appended_user_agent=appended_user_agent, + ) parent = utils.full_resource_name( resource_name=metadata_store_id, @@ -311,6 +325,13 @@ def create( Returns: Artifact: Instantiated representation of the managed Metadata Artifact. """ + # Add User Agent Header for metrics tracking if one is not specified + # If one is already specified this call was initiated by a sub class. + if not base_constants.USER_AGENT_SDK_COMMAND: + base_constants.USER_AGENT_SDK_COMMAND = ( + "aiplatform.metadata.artifact.Artifact.create" + ) + return cls._create( resource_id=resource_id, schema_title=schema_title, diff --git a/google/cloud/aiplatform/metadata/context.py b/google/cloud/aiplatform/metadata/context.py index d1c7dea99c..c827f865d6 100644 --- a/google/cloud/aiplatform/metadata/context.py +++ b/google/cloud/aiplatform/metadata/context.py @@ -23,6 +23,7 @@ from google.cloud.aiplatform import base from google.cloud.aiplatform import utils +from google.cloud.aiplatform.constants import base as base_constants from google.cloud.aiplatform.metadata import utils as metadata_utils from google.cloud.aiplatform.compat.types import context as gca_context from google.cloud.aiplatform.compat.types import ( @@ -136,6 +137,13 @@ def create( Returns: Context: Instantiated representation of the managed Metadata Context. """ + # Add User Agent Header for metrics tracking if one is not specified + # If one is already specified this call was initiated by a sub class. + if not base_constants.USER_AGENT_SDK_COMMAND: + base_constants.USER_AGENT_SDK_COMMAND = ( + "aiplatform.metadata.context.Context.create" + ) + return cls._create( resource_id=resource_id, schema_title=schema_title, @@ -202,7 +210,19 @@ def _create( Instantiated representation of the managed Metadata resource. """ - api_client = cls._instantiate_client(location=location, credentials=credentials) + appended_user_agent = [] + if base_constants.USER_AGENT_SDK_COMMAND: + appended_user_agent = [ + f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}" + ] + # Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls. + base_constants.USER_AGENT_SDK_COMMAND = "" + + api_client = cls._instantiate_client( + location=location, + credentials=credentials, + appended_user_agent=appended_user_agent, + ) parent = utils.full_resource_name( resource_name=metadata_store_id, diff --git a/google/cloud/aiplatform/metadata/execution.py b/google/cloud/aiplatform/metadata/execution.py index 9a85bce36f..5a4e19f4f8 100644 --- a/google/cloud/aiplatform/metadata/execution.py +++ b/google/cloud/aiplatform/metadata/execution.py @@ -20,7 +20,6 @@ import proto from google.auth import credentials as auth_credentials -from google.cloud.aiplatform import base from google.cloud.aiplatform import models from google.cloud.aiplatform import utils from google.cloud.aiplatform.compat.types import event as gca_event @@ -28,6 +27,7 @@ from google.cloud.aiplatform.compat.types import ( metadata_service as gca_metadata_service, ) +from google.cloud.aiplatform.constants import base as base_constants from google.cloud.aiplatform.metadata import artifact from google.cloud.aiplatform.metadata import metadata_store from google.cloud.aiplatform.metadata import resource @@ -142,18 +142,110 @@ def create( Execution: Instantiated representation of the managed Metadata Execution. """ - self = cls._empty_constructor( - project=project, location=location, credentials=credentials + # Add User Agent Header for metrics tracking if one is not specified + # If one is already specified this call was initiated by a sub class. + if not base_constants.USER_AGENT_SDK_COMMAND: + base_constants.USER_AGENT_SDK_COMMAND = ( + "aiplatform.metadata.execution.Execution.create" + ) + + return cls._create( + resource_id=resource_id, + schema_title=schema_title, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=metadata, + state=state, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + + # TODO() refactor code to move _create to _Resource class. + @classmethod + def _create( + cls, + schema_title: str, + *, + state: gca_execution.Execution.State = gca_execution.Execution.State.RUNNING, + resource_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + description: Optional[str] = None, + metadata_store_id: str = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials=Optional[auth_credentials.Credentials], + ) -> "Execution": + """ + Creates a new Metadata Execution. + + Args: + schema_title (str): + Required. schema_title identifies the schema title used by the Execution. + state (gca_execution.Execution.State.RUNNING): + Optional. State of this Execution. Defaults to RUNNING. + resource_id (str): + Optional. The portion of the Execution name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//executions/. + display_name (str): + Optional. The user-defined name of the Execution. + schema_version (str): + Optional. schema_version specifies the version used by the Execution. + If not set, defaults to use the latest version. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Execution. + description (str): + Optional. Describes the purpose of the Execution to be created. + metadata_store_id (str): + Optional. The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores//artifacts/ + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Optional. Project used to create this Execution. Overrides project set in + aiplatform.init. + location (str): + Optional. Location used to create this Execution. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials used to create this Execution. Overrides + credentials set in aiplatform.init. + + Returns: + Execution: Instantiated representation of the managed Metadata Execution. + + """ + appended_user_agent = [] + if base_constants.USER_AGENT_SDK_COMMAND: + appended_user_agent = [ + f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}" + ] + # Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls. + base_constants.USER_AGENT_SDK_COMMAND = "" + + api_client = cls._instantiate_client( + location=location, + credentials=credentials, + appended_user_agent=appended_user_agent, + ) + + parent = utils.full_resource_name( + resource_name=metadata_store_id, + resource_noun=metadata_store._MetadataStore._resource_noun, + parse_resource_name_method=metadata_store._MetadataStore._parse_resource_name, + format_resource_name_method=metadata_store._MetadataStore._format_resource_name, + project=project, + location=location, ) - super(base.VertexAiResourceNounWithFutureManager, self).__init__() resource = Execution._create_resource( - client=self.api_client, - parent=metadata_store._MetadataStore._format_resource_name( - project=self.project, - location=self.location, - metadata_store=metadata_store_id, - ), + client=api_client, + parent=parent, schema_title=schema_title, resource_id=resource_id, metadata=metadata, @@ -162,6 +254,9 @@ def create( schema_version=schema_version, state=state, ) + self = cls._empty_constructor( + project=project, location=location, credentials=credentials + ) self._gca_resource = resource return self diff --git a/google/cloud/aiplatform/metadata/metadata_store.py b/google/cloud/aiplatform/metadata/metadata_store.py index 2f0c8e2955..ab2e7ec305 100644 --- a/google/cloud/aiplatform/metadata/metadata_store.py +++ b/google/cloud/aiplatform/metadata/metadata_store.py @@ -25,6 +25,7 @@ from google.cloud.aiplatform import compat from google.cloud.aiplatform import utils from google.cloud.aiplatform.compat.types import metadata_store as gca_metadata_store +from google.cloud.aiplatform.constants import base as base_constants class _MetadataStore(base.VertexAiResourceNounWithFutureManager): @@ -115,7 +116,6 @@ def get_or_create( Instantiated representation of the managed metadata store resource. """ - store = cls._get( metadata_store_name=metadata_store_id, project=project, @@ -176,7 +176,20 @@ def _create( Instantiated representation of the managed metadata store resource. """ - api_client = cls._instantiate_client(location=location, credentials=credentials) + appended_user_agent = [] + if base_constants.USER_AGENT_SDK_COMMAND: + appended_user_agent = [ + f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}" + ] + # Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls. + base_constants.USER_AGENT_SDK_COMMAND = "" + + api_client = cls._instantiate_client( + location=location, + credentials=credentials, + appended_user_agent=appended_user_agent, + ) + gapic_metadata_store = gca_metadata_store.MetadataStore( encryption_spec=initializer.global_config.get_encryption_spec( encryption_spec_key_name=encryption_spec_key_name, diff --git a/google/cloud/aiplatform/metadata/schema/base_artifact.py b/google/cloud/aiplatform/metadata/schema/base_artifact.py index 357ce9d58a..e4b90d8b6a 100644 --- a/google/cloud/aiplatform/metadata/schema/base_artifact.py +++ b/google/cloud/aiplatform/metadata/schema/base_artifact.py @@ -23,6 +23,7 @@ from google.cloud.aiplatform.compat.types import artifact as gca_artifact from google.cloud.aiplatform.metadata import artifact +from google.cloud.aiplatform.constants import base as base_constants from google.cloud.aiplatform.metadata import constants @@ -114,6 +115,11 @@ def _init_with_resource_name( Artifact name with the following format, this is globally unique in a metadataStore: projects/123/locations/us-central1/metadataStores//artifacts/. """ + # Add User Agent Header for metrics tracking if one is not specified + # If one is already specified this call was initiated by a sub class. + if not base_constants.USER_AGENT_SDK_COMMAND: + base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_artifact.BaseArtifactSchema._init_with_resource_name" + super(BaseArtifactSchema, self).__init__(artifact_name=artifact_name) def create( @@ -144,6 +150,10 @@ def create( Returns: Artifact: Instantiated representation of the managed Metadata Artifact. """ + # Add User Agent Header for metrics tracking. + base_constants.USER_AGENT_SDK_COMMAND = ( + "aiplatform.metadata.schema.base_artifact.BaseArtifactSchema.create" + ) # Check if metadata exists to avoid proto read error metadata = None diff --git a/google/cloud/aiplatform/metadata/schema/base_context.py b/google/cloud/aiplatform/metadata/schema/base_context.py index b6d7f5b4d7..618bda3b60 100644 --- a/google/cloud/aiplatform/metadata/schema/base_context.py +++ b/google/cloud/aiplatform/metadata/schema/base_context.py @@ -22,6 +22,7 @@ from google.auth import credentials as auth_credentials from google.cloud.aiplatform.compat.types import context as gca_context +from google.cloud.aiplatform.constants import base as base_constants from google.cloud.aiplatform.metadata import constants from google.cloud.aiplatform.metadata import context @@ -91,6 +92,11 @@ def _init_with_resource_name( Context name with the following format, this is globally unique in a metadataStore: projects/123/locations/us-central1/metadataStores//contexts/. """ + # Add User Agent Header for metrics tracking if one is not specified + # If one is already specified this call was initiated by a sub class. + if not base_constants.USER_AGENT_SDK_COMMAND: + base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_context.BaseContextSchema._init_with_resource_name" + super(BaseContextSchema, self).__init__(resource_name=context_name) def create( @@ -122,6 +128,11 @@ def create( Context: Instantiated representation of the managed Metadata Context. """ + # Add User Agent Header for metrics tracking. + base_constants.USER_AGENT_SDK_COMMAND = ( + "aiplatform.metadata.schema.base_context.BaseContextSchema.create" + ) + # Check if metadata exists to avoid proto read error metadata = None if self._gca_resource.metadata: diff --git a/google/cloud/aiplatform/metadata/schema/base_execution.py b/google/cloud/aiplatform/metadata/schema/base_execution.py index 1cbf66b825..2f392c856c 100644 --- a/google/cloud/aiplatform/metadata/schema/base_execution.py +++ b/google/cloud/aiplatform/metadata/schema/base_execution.py @@ -22,6 +22,7 @@ from google.auth import credentials as auth_credentials from google.cloud.aiplatform.compat.types import execution as gca_execution +from google.cloud.aiplatform.constants import base as base_constants from google.cloud.aiplatform.metadata import constants from google.cloud.aiplatform.metadata import execution from google.cloud.aiplatform.metadata import metadata @@ -100,6 +101,11 @@ def _init_with_resource_name( The Execution name with the following format, this is globally unique in a metadataStore. projects/123/locations/us-central1/metadataStores//executions/. """ + # Add User Agent Header for metrics tracking if one is not specified + # If one is already specified this call was initiated by a sub class. + if not base_constants.USER_AGENT_SDK_COMMAND: + base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_execution.BaseExecutionSchema._init_with_resource_name" + super(BaseExecutionSchema, self).__init__(execution_name=execution_name) def create( @@ -131,6 +137,12 @@ def create( Execution: Instantiated representation of the managed Metadata Execution. """ + # Add User Agent Header for metrics tracking if one is not specified + # If one is already specified this call was initiated by a sub class. + base_constants.USER_AGENT_SDK_COMMAND = ( + "aiplatform.metadata.schema.base_execution.BaseExecutionSchema.create" + ) + # Check if metadata exists to avoid proto read error metadata = None if self._gca_resource.metadata: @@ -208,6 +220,11 @@ def start_execution( Raises: ValueError: If metadata_store_id other than 'default' is provided. """ + # Add User Agent Header for metrics tracking if one is not specified + # If one is already specified this call was initiated by a sub class. + + base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_execution.BaseExecutionSchema.start_execution" + if metadata_store_id != "default": raise ValueError( f"metadata_store_id {metadata_store_id} is not supported. Only the default MetadataStore ID is supported." diff --git a/tests/unit/aiplatform/test_metadata_schema.py b/tests/unit/aiplatform/test_metadata_schema.py index e85229c555..65de1304ea 100644 --- a/tests/unit/aiplatform/test_metadata_schema.py +++ b/tests/unit/aiplatform/test_metadata_schema.py @@ -99,6 +99,38 @@ def get_artifact_mock(): yield get_artifact_mock +@pytest.fixture +def initializer_create_client_mock(): + with patch.object( + initializer.global_config, "create_client" + ) as initializer_create_client_mock: + yield initializer_create_client_mock + + +@pytest.fixture +def base_artifact_init_with_resouce_name_mock(): + with patch.object( + base_artifact.BaseArtifactSchema, "_init_with_resource_name" + ) as base_artifact_init_with_resouce_name_mock: + yield base_artifact_init_with_resouce_name_mock + + +@pytest.fixture +def base_execution_init_with_resouce_name_mock(): + with patch.object( + base_execution.BaseExecutionSchema, "_init_with_resource_name" + ) as base_execution_init_with_resouce_name_mock: + yield base_execution_init_with_resouce_name_mock + + +@pytest.fixture +def base_context_init_with_resouce_name_mock(): + with patch.object( + base_context.BaseContextSchema, "_init_with_resource_name" + ) as base_context_init_with_resouce_name_mock: + yield base_context_init_with_resouce_name_mock + + @pytest.fixture def get_execution_mock(): with patch.object(MetadataServiceClient, "get_execution") as get_execution_mock: @@ -246,6 +278,59 @@ class TestArtifact(base_artifact.BaseArtifactSchema): assert kwargs["artifact"].metadata == _TEST_UPDATED_METADATA assert kwargs["artifact"].state == _TEST_ARTIFACT_STATE + @pytest.mark.usefixtures( + "base_artifact_init_with_resouce_name_mock", + "initializer_create_client_mock", + "create_artifact_mock", + "get_artifact_mock", + ) + def test_artifact_create_call_sets_the_user_agent_header( + self, initializer_create_client_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + class TestArtifact(base_artifact.BaseArtifactSchema): + schema_title = _TEST_SCHEMA_TITLE + + artifact = TestArtifact( + uri=_TEST_URI, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + state=_TEST_ARTIFACT_STATE, + ) + artifact.create() + _, _, kwargs = initializer_create_client_mock.mock_calls[0] + assert kwargs["appended_user_agent"] == [ + "sdk_command/aiplatform.metadata.schema.base_artifact.BaseArtifactSchema.create" + ] + + @pytest.mark.usefixtures( + "initializer_create_client_mock", + "create_artifact_mock", + "get_artifact_mock", + ) + def test_artifact_init_call_sets_the_user_agent_header( + self, initializer_create_client_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + class TestArtifact(base_artifact.BaseArtifactSchema): + schema_title = _TEST_SCHEMA_TITLE + + artifact = TestArtifact( + uri=_TEST_URI, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + state=_TEST_ARTIFACT_STATE, + ) + artifact._init_with_resource_name(artifact_name=_TEST_ARTIFACT_NAME) + _, _, kwargs = initializer_create_client_mock.mock_calls[0] + assert kwargs["appended_user_agent"] == [ + "sdk_command/aiplatform.metadata.schema.base_artifact.BaseArtifactSchema._init_with_resource_name" + ] + @pytest.mark.usefixtures("google_auth_mock") class TestMetadataBaseExecutionSchema: @@ -316,6 +401,204 @@ class TestExecution(base_execution.BaseExecutionSchema): assert kwargs["execution"].description == _TEST_DESCRIPTION assert kwargs["execution"].metadata == _TEST_UPDATED_METADATA + @pytest.mark.usefixtures( + "base_execution_init_with_resouce_name_mock", + "initializer_create_client_mock", + "create_execution_mock", + "get_execution_mock", + ) + def test_execution_create_call_sets_the_user_agent_header( + self, initializer_create_client_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + class TestExecution(base_execution.BaseExecutionSchema): + schema_title = _TEST_SCHEMA_TITLE + + execution = TestExecution( + state=_TEST_EXECUTION_STATE, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + execution.create(metadata_store_id=_TEST_METADATA_STORE) + _, _, kwargs = initializer_create_client_mock.mock_calls[0] + assert kwargs["appended_user_agent"] == [ + "sdk_command/aiplatform.metadata.schema.base_execution.BaseExecutionSchema.create" + ] + + @pytest.mark.usefixtures( + "base_execution_init_with_resouce_name_mock", + "initializer_create_client_mock", + "create_execution_mock", + "get_execution_mock", + ) + def test_execution_start_execution_call_sets_the_user_agent_header( + self, initializer_create_client_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + class TestExecution(base_execution.BaseExecutionSchema): + schema_title = _TEST_SCHEMA_TITLE + + execution = TestExecution( + state=_TEST_EXECUTION_STATE, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + execution.start_execution() + _, _, kwargs = initializer_create_client_mock.mock_calls[0] + assert kwargs["appended_user_agent"] == [ + "sdk_command/aiplatform.metadata.schema.base_execution.BaseExecutionSchema.start_execution" + ] + + @pytest.mark.usefixtures( + "initializer_create_client_mock", + "create_execution_mock", + "get_execution_mock", + ) + def test_execution_init_call_sets_the_user_agent_header( + self, initializer_create_client_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + class TestExecution(base_execution.BaseExecutionSchema): + schema_title = _TEST_SCHEMA_TITLE + + execution = TestExecution( + state=_TEST_EXECUTION_STATE, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + execution._init_with_resource_name(execution_name=_TEST_EXECUTION_NAME) + _, _, kwargs = initializer_create_client_mock.mock_calls[0] + assert kwargs["appended_user_agent"] == [ + "sdk_command/aiplatform.metadata.schema.base_execution.BaseExecutionSchema._init_with_resource_name" + ] + + +@pytest.mark.usefixtures("google_auth_mock") +class TestMetadataBaseContextSchema: + def setup_method(self): + reload(initializer) + reload(metadata) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_base_context_class_instatiated_uses_schema_title(self): + class TestContext(base_context.BaseContextSchema): + schema_title = _TEST_SCHEMA_TITLE + + context = TestContext() + assert context.schema_title == _TEST_SCHEMA_TITLE + + def test_base_context_class_parameters_overrides_default_values(self): + class TestContext(base_context.BaseContextSchema): + schema_title = _TEST_SCHEMA_TITLE + + context = TestContext( + schema_version=_TEST_SCHEMA_VERSION, + context_id=_TEST_CONTEXT_ID, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + assert context.schema_version == _TEST_SCHEMA_VERSION + assert context.context_id == _TEST_CONTEXT_ID + assert context.schema_title == _TEST_SCHEMA_TITLE + assert context.display_name == _TEST_DISPLAY_NAME + assert context.description == _TEST_DESCRIPTION + assert context.metadata == _TEST_UPDATED_METADATA + + def test_base_context_class_without_schema_title_raises_error(self): + with pytest.raises(TypeError): + base_context.BaseContextSchema() + + @pytest.mark.usefixtures("create_context_mock", "get_context_mock") + def test_base_context_create_is_called_with_default_parameters( + self, create_context_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + class TestContext(base_context.BaseContextSchema): + schema_title = _TEST_SCHEMA_TITLE + + context = TestContext( + schema_version=_TEST_SCHEMA_VERSION, + context_id=_TEST_CONTEXT_ID, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + context.create(metadata_store_id=_TEST_METADATA_STORE) + create_context_mock.assert_called_once_with( + parent=f"{_TEST_PARENT}/metadataStores/{_TEST_METADATA_STORE}", + context=mock.ANY, + context_id=_TEST_CONTEXT_ID, + ) + _, _, kwargs = create_context_mock.mock_calls[0] + assert kwargs["context"].schema_title == _TEST_SCHEMA_TITLE + assert kwargs["context"].display_name == _TEST_DISPLAY_NAME + assert kwargs["context"].description == _TEST_DESCRIPTION + assert kwargs["context"].metadata == _TEST_UPDATED_METADATA + + @pytest.mark.usefixtures( + "base_context_init_with_resouce_name_mock", + "initializer_create_client_mock", + "create_context_mock", + "get_context_mock", + ) + def test_base_context_create_call_sets_the_user_agent_header( + self, initializer_create_client_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + class TestContext(base_context.BaseContextSchema): + schema_title = _TEST_SCHEMA_TITLE + + context = TestContext( + schema_version=_TEST_SCHEMA_VERSION, + context_id=_TEST_CONTEXT_ID, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + context.create() + _, _, kwargs = initializer_create_client_mock.mock_calls[0] + assert kwargs["appended_user_agent"] == [ + "sdk_command/aiplatform.metadata.schema.base_context.BaseContextSchema.create" + ] + + @pytest.mark.usefixtures( + "initializer_create_client_mock", + "create_context_mock", + "get_context_mock", + ) + def test_base_context_init_call_sets_the_user_agent_header( + self, initializer_create_client_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + class TestContext(base_context.BaseContextSchema): + schema_title = _TEST_SCHEMA_TITLE + + context = TestContext( + schema_version=_TEST_SCHEMA_VERSION, + context_id=_TEST_CONTEXT_ID, + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + context._init_with_resource_name(context_name=_TEST_CONTEXT_NAME) + _, _, kwargs = initializer_create_client_mock.mock_calls[0] + assert kwargs["appended_user_agent"] == [ + "sdk_command/aiplatform.metadata.schema.base_context.BaseContextSchema._init_with_resource_name" + ] + @pytest.mark.usefixtures("google_auth_mock") class TestMetadataGoogleArtifactSchema: