diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index d48547c9e3..3e6e2a1c84 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -576,6 +576,13 @@ def mock_get_execution(mock_execution): yield mock_get_execution +@pytest.fixture +def mock_create_execution(mock_execution): + with patch.object(aiplatform.Execution, "create") as mock_create_execution: + mock_create_execution.return_value = mock_execution + yield mock_create_execution + + @pytest.fixture def mock_get_artifact(mock_artifact): with patch.object(aiplatform, "Artifact") as mock_get_artifact: diff --git a/samples/model-builder/experiment_tracking/create_artifact_with_sdk_sample.py b/samples/model-builder/experiment_tracking/create_artifact_with_sdk_sample.py new file mode 100644 index 0000000000..3d67cacc17 --- /dev/null +++ b/samples/model-builder/experiment_tracking/create_artifact_with_sdk_sample.py @@ -0,0 +1,43 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional + +from google.cloud.aiplatform.metadata.schema.system import artifact_schema + + +# [START aiplatform_sdk_create_artifact_with_sdk_sample] +def create_artifact_sample( + project: str, + location: str, + uri: Optional[str] = None, + artifact_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, +): + system_artifact_schema = artifact_schema.Artifact( + uri=uri, + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=metadata, + ) + + return system_artifact_schema.create(project=project, location=location,) + + +# [END aiplatform_sdk_create_artifact_with_sdk_sample] diff --git a/samples/model-builder/experiment_tracking/create_artifact_with_sdk_sample_test.py b/samples/model-builder/experiment_tracking/create_artifact_with_sdk_sample_test.py new file mode 100644 index 0000000000..09b9249ff2 --- /dev/null +++ b/samples/model-builder/experiment_tracking/create_artifact_with_sdk_sample_test.py @@ -0,0 +1,49 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import create_artifact_with_sdk_sample + +from google.cloud.aiplatform.compat.types import artifact as gca_artifact + +import test_constants as constants + + +def test_create_artifact_with_sdk_sample(mock_artifact, mock_create_artifact): + artifact = create_artifact_with_sdk_sample.create_artifact_sample( + project=constants.PROJECT, + location=constants.LOCATION, + uri=constants.MODEL_ARTIFACT_URI, + artifact_id=constants.RESOURCE_ID, + display_name=constants.DISPLAY_NAME, + schema_version=constants.SCHEMA_VERSION, + description=constants.DESCRIPTION, + metadata=constants.METADATA, + ) + + mock_create_artifact.assert_called_with( + resource_id=constants.RESOURCE_ID, + schema_title="system.Artifact", + uri=constants.MODEL_ARTIFACT_URI, + display_name=constants.DISPLAY_NAME, + schema_version=constants.SCHEMA_VERSION, + description=constants.DESCRIPTION, + metadata=constants.METADATA, + state=gca_artifact.Artifact.State.LIVE, + metadata_store_id="default", + project=constants.PROJECT, + location=constants.LOCATION, + credentials=None, + ) + + assert artifact is mock_artifact diff --git a/samples/model-builder/experiment_tracking/create_execution_with_sdk_sample.py b/samples/model-builder/experiment_tracking/create_execution_with_sdk_sample.py new file mode 100644 index 0000000000..ac0faa7065 --- /dev/null +++ b/samples/model-builder/experiment_tracking/create_execution_with_sdk_sample.py @@ -0,0 +1,47 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional + +from google.cloud import aiplatform +from google.cloud.aiplatform.metadata.schema.system import execution_schema + + +# [START aiplatform_sdk_create_execution_with_sdk_sample] +def create_execution_sample( + display_name: str, + input_artifacts: List[aiplatform.Artifact], + output_artifacts: List[aiplatform.Artifact], + project: str, + location: str, + execution_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, +): + aiplatform.init(project=project, location=location) + + with execution_schema.ContainerExecution( + display_name=display_name, + execution_id=execution_id, + metadata=metadata, + schema_version=schema_version, + description=description, + ).create() as execution: + execution.assign_input_artifacts(input_artifacts) + execution.assign_output_artifacts(output_artifacts) + return execution + + +# [END aiplatform_sdk_create_execution_with_sdk_sample] diff --git a/samples/model-builder/experiment_tracking/create_execution_with_sdk_sample_test.py b/samples/model-builder/experiment_tracking/create_execution_with_sdk_sample_test.py new file mode 100644 index 0000000000..54e9563a12 --- /dev/null +++ b/samples/model-builder/experiment_tracking/create_execution_with_sdk_sample_test.py @@ -0,0 +1,62 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import create_execution_with_sdk_sample + +from google.cloud.aiplatform.compat.types import execution as gca_execution + +import test_constants as constants + + +def test_create_execution_sample( + mock_sdk_init, mock_create_artifact, mock_create_execution, mock_execution, +): + + input_art = mock_create_artifact() + output_art = mock_create_artifact() + + exc = create_execution_with_sdk_sample.create_execution_sample( + display_name=constants.DISPLAY_NAME, + input_artifacts=[input_art], + output_artifacts=[output_art], + project=constants.PROJECT, + location=constants.LOCATION, + execution_id=constants.RESOURCE_ID, + metadata=constants.METADATA, + schema_version=constants.SCHEMA_VERSION, + description=constants.DESCRIPTION, + ) + + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION, + ) + + mock_create_execution.assert_called_with( + state=gca_execution.Execution.State.RUNNING, + schema_title="system.ContainerExecution", + resource_id=constants.RESOURCE_ID, + display_name=constants.DISPLAY_NAME, + schema_version=constants.SCHEMA_VERSION, + metadata=constants.METADATA, + description=constants.DESCRIPTION, + metadata_store_id="default", + project=None, + location=None, + credentials=None, + ) + + mock_execution.assign_input_artifacts.assert_called_with([input_art]) + mock_execution.assign_output_artifacts.assert_called_with([output_art]) + + assert exc is mock_execution