Skip to content

Commit

Permalink
docs: samples for model serialization
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 504033183
  • Loading branch information
jaycee-li authored and copybara-github committed Jan 23, 2023
1 parent f38ddc2 commit 7997094
Show file tree
Hide file tree
Showing 12 changed files with 459 additions and 0 deletions.
68 changes: 68 additions & 0 deletions samples/model-builder/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,30 @@ def mock_artifacts():
yield mock


@pytest.fixture
def mock_experiment_models():
mock = MagicMock()
yield mock


@pytest.fixture
def mock_model_info():
mock = MagicMock()
yield mock


@pytest.fixture
def mock_ml_model():
mock = MagicMock()
yield mock


@pytest.fixture
def mock_experiment_model():
mock = MagicMock(aiplatform.metadata.schema.google.artifact_schema.ExperimentModel)
yield mock


@pytest.fixture
def mock_get_execution(mock_execution):
with patch.object(aiplatform, "Execution") as mock_get_execution:
Expand Down Expand Up @@ -870,6 +894,13 @@ def mock_log_classification_metrics():
yield mock_log_metrics


@pytest.fixture
def mock_log_model():
with patch.object(aiplatform, "log_model") as mock_log_metrics:
mock_log_metrics.return_value = None
yield mock_log_metrics


@pytest.fixture
def mock_log_pipeline_job():
with patch.object(aiplatform, "log") as mock_log_pipeline_job:
Expand Down Expand Up @@ -944,6 +975,43 @@ def mock_get_artifacts(mock_artifacts, mock_experiment_run):
yield mock_get_artifacts


@pytest.fixture
def mock_get_experiment_models(mock_experiment_models, mock_experiment_run):
with patch.object(
mock_experiment_run, "get_experiment_models"
) as mock_get_experiment_models:
mock_get_experiment_models.return_value = mock_experiment_models
yield mock_get_experiment_models


@pytest.fixture
def mock_get_experiment_model(mock_experiment_model):
with patch.object(aiplatform, "get_experiment_model") as mock_get_experiment_model:
mock_get_experiment_model.return_value = mock_experiment_model
yield mock_get_experiment_model


@pytest.fixture
def mock_get_model_info(mock_experiment_model, mock_model_info):
with patch.object(mock_experiment_model, "get_model_info") as mock_get_model_info:
mock_get_model_info.return_value = mock_model_info
yield mock_get_model_info


@pytest.fixture
def mock_load_model(mock_experiment_model, mock_ml_model):
with patch.object(mock_experiment_model, "load_model") as mock_load_model:
mock_load_model.return_value = mock_ml_model
yield mock_load_model


@pytest.fixture
def mock_register_model(mock_experiment_model, mock_model):
with patch.object(mock_experiment_model, "register_model") as mock_register_model:
mock_register_model.return_value = mock_model
yield mock_register_model


"""
----------------------------------------------------------------------------
Model Versioning Fixtures
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

from google.cloud import aiplatform


# [START aiplatform_sdk_get_experiment_run_models_sample]
def get_experiment_run_models_sample(
run_name: str,
experiment: Union[str, aiplatform.Experiment],
project: str,
location: str,
) -> List["ExperimentModel"]: # noqa: F821
experiment_run = aiplatform.ExperimentRun(
run_name=run_name, experiment=experiment, project=project, location=location
)

return experiment_run.get_experiment_models()


# [END aiplatform_sdk_get_experiment_run_models_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import get_experiment_run_models_sample

import pytest

import test_constants as constants


@pytest.mark.usefixtures("mock_get_run")
def test_get_experiment_run_models_sample(
mock_get_experiment_models, mock_experiment_models
):

experiment_models = (
get_experiment_run_models_sample.get_experiment_run_models_sample(
run_name=constants.EXPERIMENT_RUN_NAME,
experiment=constants.EXPERIMENT_NAME,
project=constants.PROJECT,
location=constants.LOCATION,
)
)

mock_get_experiment_models.assert_called_once()

assert experiment_models is mock_experiment_models
33 changes: 33 additions & 0 deletions samples/model-builder/experiment_tracking/get_model_info_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict

from google.cloud import aiplatform


# [START aiplatform_sdk_get_model_info_sample]
def get_model_info_sample(
artifact_id: str,
project: str,
location: str,
) -> Dict[str, Any]:
experiment_model = aiplatform.get_experiment_model(
artifact_id=artifact_id, project=project, location=location
)

return experiment_model.get_model_info()


# [END aiplatform_sdk_get_model_info_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import get_model_info_sample

import pytest

import test_constants as constants


@pytest.mark.usefixtures("mock_get_run")
def test_get_model_info_sample(
mock_get_experiment_model, mock_get_model_info, mock_model_info
):

model_info = get_model_info_sample.get_model_info_sample(
artifact_id=constants.EXPERIMENT_MODEL_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)

mock_get_experiment_model.assert_called_once_with(
artifact_id=constants.EXPERIMENT_MODEL_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)
mock_get_model_info.assert_called_once()

assert model_info is mock_model_info
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

from google.cloud import aiplatform


# [START aiplatform_sdk_load_experiment_model_sample]
def load_experiment_model_sample(
artifact_id: str,
project: str,
location: str,
) -> Union["sklearn.base.BaseEstimator", "xgb.Booster", "tf.Module"]: # noqa: F821:
experiment_model = aiplatform.get_experiment_model(
artifact_id=artifact_id, project=project, location=location
)

return experiment_model.load_model()


# [END aiplatform_sdk_load_experiment_model_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import load_experiment_model_sample

import pytest

import test_constants as constants


@pytest.mark.usefixtures("mock_get_run")
def test_load_experiment_model_sample(
mock_get_experiment_model, mock_load_model, mock_ml_model
):

ml_model = load_experiment_model_sample.load_experiment_model_sample(
artifact_id=constants.EXPERIMENT_MODEL_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)

mock_get_experiment_model.assert_called_once_with(
artifact_id=constants.EXPERIMENT_MODEL_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)
mock_load_model.assert_called_once()

assert ml_model is mock_ml_model
49 changes: 49 additions & 0 deletions samples/model-builder/experiment_tracking/log_model_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

from google.cloud import aiplatform


# [START aiplatform_sdk_log_model_sample]
def log_model_sample(
experiment_name: str,
run_name: str,
project: str,
location: str,
model: Union[
"sklearn.base.BaseEstimator", "xgb.Booster", "tf.Module" # noqa: F821
],
artifact_id: Optional[str] = None,
uri: Optional[str] = None,
input_example: Optional[
Union[list, dict, "pd.DataFrame", "np.ndarray"] # noqa: F821
] = None, # noqa: F821
display_name: Optional[str] = None,
) -> None:
aiplatform.init(experiment=experiment_name, project=project, location=location)

aiplatform.start_run(run=run_name, resume=True)

aiplatform.log_model(
model=model,
artifact_id=artifact_id,
uri=uri,
input_example=input_example,
display_name=display_name,
)


# [END aiplatform_sdk_log_model_sample]
Loading

0 comments on commit 7997094

Please sign in to comment.