Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support model monitoring for batch prediction in Vertex SDK #1570

Merged
merged 8 commits into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ def create(
sync: bool = True,
create_request_timeout: Optional[float] = None,
batch_size: Optional[int] = None,
model_monitoring_objective_config: Optional[
"aiplatform.model_monitoring.ObjectiveConfig"
] = None,
model_monitoring_alert_config: Optional[
"aiplatform.model_monitoring.AlertConfig"
] = None,
analysis_instance_schema_uri: Optional[str] = None,
) -> "BatchPredictionJob":
"""Create a batch prediction job.

Expand Down Expand Up @@ -551,6 +558,23 @@ def create(
but too high value will result in a whole batch not fitting in a machine's memory,
and the whole operation will fail.
The default value is 64.
model_monitoring_objective_config (aiplatform.model_monitoring.ObjectiveConfig):
Optional. The objective config for model monitoring. Passing this parameter enables
monitoring on the model associated with this batch prediction job.
model_monitoring_alert_config (aiplatform.model_monitoring.EmailAlertConfig):
Optional. Configures how model monitoring alerts are sent to the user. Right now
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these docstrings copied from the source at https://github.com/googleapis/googleapis/tree/master/google/cloud/aiplatform? Will we be able to remember to update this when/if alerts other than email alert become supported?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're not directly copied from GAPIC. But Jing's team also confirmed that there's no plans for additional alert configs.

only email alert is supported.
analysis_instance_schema_uri (str):
Optional. Only applicable if model_monitoring_objective_config is also passed.
This parameter specifies the YAML schema file uri describing the format of a single
instance that you want Tensorflow Data Validation (TFDV) to
analyze. If this field is empty, all the feature data types are
inferred from predict_instance_schema_uri, meaning that TFDV
will use the data in the exact format as prediction request/response.
If there are any data type differences between predict instance
and TFDV instance, this field can be used to override the schema.
For models trained with Vertex AI, this field must be set as all the
fields in predict instance formatted as string.
Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
Expand Down Expand Up @@ -601,7 +625,18 @@ def create(
f"{predictions_format} is not an accepted prediction format "
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
)

# TODO: remove temporary import statements once model monitoring for batch prediction is GA
if model_monitoring_objective_config:
from google.cloud.aiplatform.compat.types import (
io_v1beta1 as gca_io_compat,
batch_prediction_job_v1beta1 as gca_bp_job_compat,
model_monitoring_v1beta1 as gca_model_monitoring_compat,
)
else:
from google.cloud.aiplatform.compat.types import (
io as gca_io_compat,
batch_prediction_job as gca_bp_job_compat,
)
gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob()

# Required Fields
Expand Down Expand Up @@ -688,6 +723,28 @@ def create(
)
)

# Model Monitoring
if model_monitoring_objective_config:
if model_monitoring_objective_config.drift_detection_config:
_LOGGER.info(
"Drift detection config is currently not supported for monitoring models associated with batch prediction jobs."
)
if model_monitoring_objective_config.explanation_config:
_LOGGER.info(
"XAI config is currently not supported for monitoring models associated with batch prediction jobs."
)
gapic_batch_prediction_job.model_monitoring_config = (
gca_model_monitoring_compat.ModelMonitoringConfig(
objective_configs=[
model_monitoring_objective_config.as_proto(config_for_bp=True)
],
alert_config=model_monitoring_alert_config.as_proto(
config_for_bp=True
),
analysis_instance_schema_uri=analysis_instance_schema_uri,
)
)

empty_batch_prediction_job = cls._empty_constructor(
project=project,
location=location,
Expand All @@ -702,6 +759,11 @@ def create(
sync=sync,
create_request_timeout=create_request_timeout,
)
# TODO: b/242108750
from google.cloud.aiplatform.compat.types import (
io as gca_io_compat,
batch_prediction_job as gca_bp_job_compat,
)

@classmethod
@base.optional_sync(return_input_arg="empty_batch_prediction_job")
Expand Down
24 changes: 21 additions & 3 deletions google/cloud/aiplatform/model_monitoring/alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@

from typing import Optional, List
from google.cloud.aiplatform_v1.types import (
model_monitoring as gca_model_monitoring,
model_monitoring as gca_model_monitoring_v1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use compat for these imports?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compat by default imports from GA version of GAPIC, unless if DEFAULT_VERSION is set to v1beta1. I tried to see if the import aliases will index into the correct version by simply setting DEFAULT_VERSION = 'v1beta1' and then switching it back to v1 on an ad-hoc basis, but it doesn't dynamically index in the way I was hoping for. I think it's because the symbol table isn't automatically re-written unless if we explicitly re-import. So that's why I imported both v1 and v1beta1 versions explicitly.

)

# TODO: remove imports from v1beta1 once model monitoring for batch prediction is GA
from google.cloud.aiplatform_v1beta1.types import (
model_monitoring as gca_model_monitoring_v1beta1,
)

gca_model_monitoring = gca_model_monitoring_v1


class EmailAlertConfig:
def __init__(
Expand All @@ -40,8 +47,19 @@ def __init__(
self.enable_logging = enable_logging
self.user_emails = user_emails

def as_proto(self):
"""Returns EmailAlertConfig as a proto message."""
# TODO: remove config_for_bp parameter when model monitoring for batch prediction is GA
def as_proto(self, config_for_bp: bool = False):
"""Returns EmailAlertConfig as a proto message.

Args:
config_for_bp (bool):
Optional. Set this parameter to True if the config object
is used for model monitoring on a batch prediction job.
"""
if config_for_bp:
gca_model_monitoring = gca_model_monitoring_v1beta1
else:
gca_model_monitoring = gca_model_monitoring_v1
user_email_alert_config = (
gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig(
user_emails=self.user_emails
Expand Down
96 changes: 63 additions & 33 deletions google/cloud/aiplatform/model_monitoring/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@
from typing import Optional, Dict

from google.cloud.aiplatform_v1.types import (
io as gca_io,
ThresholdConfig as gca_threshold_config,
model_monitoring as gca_model_monitoring,
io as gca_io_v1,
model_monitoring as gca_model_monitoring_v1,
)

# TODO: b/242108750
from google.cloud.aiplatform_v1beta1.types import (
io as gca_io_v1beta1,
model_monitoring as gca_model_monitoring_v1beta1,
)

gca_model_monitoring = gca_model_monitoring_v1
gca_io = gca_io_v1

TF_RECORD = "tf-record"
CSV = "csv"
JSONL = "jsonl"
Expand Down Expand Up @@ -80,19 +88,20 @@ def __init__(
self.attribute_skew_thresholds = attribute_skew_thresholds
self.data_format = data_format
self.target_field = target_field
self.training_dataset = None

def as_proto(self):
"""Returns _SkewDetectionConfig as a proto message."""
skew_thresholds_mapping = {}
attribution_score_skew_thresholds_mapping = {}
if self.skew_thresholds is not None:
for key in self.skew_thresholds.keys():
skew_threshold = gca_threshold_config(value=self.skew_thresholds[key])
skew_threshold = gca_model_monitoring.ThresholdConfig(
value=self.skew_thresholds[key]
)
skew_thresholds_mapping[key] = skew_threshold
if self.attribute_skew_thresholds is not None:
for key in self.attribute_skew_thresholds.keys():
attribution_score_skew_threshold = gca_threshold_config(
attribution_score_skew_threshold = gca_model_monitoring.ThresholdConfig(
value=self.attribute_skew_thresholds[key]
)
attribution_score_skew_thresholds_mapping[
Expand Down Expand Up @@ -134,12 +143,16 @@ def as_proto(self):
attribution_score_drift_thresholds_mapping = {}
if self.drift_thresholds is not None:
for key in self.drift_thresholds.keys():
drift_threshold = gca_threshold_config(value=self.drift_thresholds[key])
drift_threshold = gca_model_monitoring.ThresholdConfig(
value=self.drift_thresholds[key]
)
drift_thresholds_mapping[key] = drift_threshold
if self.attribute_drift_thresholds is not None:
for key in self.attribute_drift_thresholds.keys():
attribution_score_drift_threshold = gca_threshold_config(
value=self.attribute_drift_thresholds[key]
attribution_score_drift_threshold = (
gca_model_monitoring.ThresholdConfig(
value=self.attribute_drift_thresholds[key]
)
)
attribution_score_drift_thresholds_mapping[
key
Expand Down Expand Up @@ -186,11 +199,49 @@ def __init__(
self.drift_detection_config = drift_detection_config
self.explanation_config = explanation_config

def as_proto(self):
"""Returns _ObjectiveConfig as a proto message."""
# TODO: b/242108750
def as_proto(self, config_for_bp: bool = False):
"""Returns _SkewDetectionConfig as a proto message.

Args:
config_for_bp (bool):
Optional. Set this parameter to True if the config object
is used for model monitoring on a batch prediction job.
"""
if config_for_bp:
gca_io = gca_io_v1beta1
gca_model_monitoring = gca_model_monitoring_v1beta1
else:
gca_io = gca_io_v1
gca_model_monitoring = gca_model_monitoring_v1
training_dataset = None
if self.skew_detection_config is not None:
training_dataset = self.skew_detection_config.training_dataset
training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
target_field=self.skew_detection_config.target_field
)
)
if self.skew_detection_config.data_source.startswith("bq:/"):
training_dataset.bigquery_source = gca_io.BigQuerySource(
input_uri=self.skew_detection_config.data_source
)
elif self.skew_detection_config.data_source.startswith("gs:/"):
training_dataset.gcs_source = gca_io.GcsSource(
uris=[self.skew_detection_config.data_source]
)
if (
self.skew_detection_config.data_format is not None
and self.skew_detection_config.data_format
not in [TF_RECORD, CSV, JSONL]
):
raise ValueError(
"Unsupported value in skew detection config. `data_format` must be one of %s, %s, or %s"
% (TF_RECORD, CSV, JSONL)
)
training_dataset.data_format = self.skew_detection_config.data_format
else:
training_dataset.dataset = self.skew_detection_config.data_source

return gca_model_monitoring.ModelMonitoringObjectiveConfig(
training_dataset=training_dataset,
training_prediction_skew_detection_config=self.skew_detection_config.as_proto()
Expand Down Expand Up @@ -271,27 +322,6 @@ def __init__(
data_format,
)

training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
target_field=target_field
)
)
if data_source.startswith("bq:/"):
training_dataset.bigquery_source = gca_io.BigQuerySource(
input_uri=data_source
)
elif data_source.startswith("gs:/"):
training_dataset.gcs_source = gca_io.GcsSource(uris=[data_source])
if data_format is not None and data_format not in [TF_RECORD, CSV, JSONL]:
raise ValueError(
"Unsupported value. `data_format` must be one of %s, %s, or %s"
% (TF_RECORD, CSV, JSONL)
)
training_dataset.data_format = data_format
else:
training_dataset.dataset = data_source
self.training_dataset = training_dataset


class DriftDetectionConfig(_DriftDetectionConfig):
"""A class that configures prediction drift detection for models deployed to an endpoint.
Expand Down
36 changes: 27 additions & 9 deletions tests/system/aiplatform/test_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@
from google.api_core import exceptions as core_exceptions
from tests.system.aiplatform import e2e_base

from google.cloud.aiplatform_v1.types import (
io as gca_io,
model_monitoring as gca_model_monitoring,
)

# constants used for testing
USER_EMAIL = ""
MODEL_NAME = "churn"
MODEL_NAME2 = "churn2"
MODEL_DISPLAYNAME_KEY = "churn"
MODEL_DISPLAYNAME_KEY2 = "churn2"
IMAGE = "us-docker.pkg.dev/cloud-aiplatform/prediction/tf2-cpu.2-5:latest"
ENDPOINT = "us-central1-aiplatform.googleapis.com"
CHURN_MODEL_PATH = "gs://mco-mm/churn"
Expand Down Expand Up @@ -139,7 +144,7 @@ def temp_endpoint(self, shared_state):
)

model = aiplatform.Model.upload(
display_name=self._make_display_name(key=MODEL_NAME),
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY),
artifact_uri=CHURN_MODEL_PATH,
serving_container_image_uri=IMAGE,
)
Expand All @@ -157,19 +162,19 @@ def temp_endpoint_with_two_models(self, shared_state):
)

model1 = aiplatform.Model.upload(
display_name=self._make_display_name(key=MODEL_NAME),
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY),
artifact_uri=CHURN_MODEL_PATH,
serving_container_image_uri=IMAGE,
)

model2 = aiplatform.Model.upload(
display_name=self._make_display_name(key=MODEL_NAME),
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY2),
artifact_uri=CHURN_MODEL_PATH,
serving_container_image_uri=IMAGE,
)
shared_state["resources"] = [model1, model2]
endpoint = aiplatform.Endpoint.create(
display_name=self._make_display_name(key=MODEL_NAME)
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY)
)
endpoint.deploy(
model=model1, machine_type="n1-standard-2", traffic_percentage=100
Expand Down Expand Up @@ -224,7 +229,14 @@ def test_mdm_one_model_one_valid_config(self, shared_state):
gca_obj_config = gapic_job.model_deployment_monitoring_objective_configs[
0
].objective_config
assert gca_obj_config.training_dataset == skew_config.training_dataset

expected_training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
target_field=TARGET,
)
)
assert gca_obj_config.training_dataset == expected_training_dataset
assert (
gca_obj_config.training_prediction_skew_detection_config
== skew_config.as_proto()
Expand Down Expand Up @@ -297,12 +309,18 @@ def test_mdm_two_models_two_valid_configs(self, shared_state):
)
assert gapic_job.model_monitoring_alert_config.enable_logging

expected_training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
target_field=TARGET,
)
)

for config in gapic_job.model_deployment_monitoring_objective_configs:
gca_obj_config = config.objective_config
deployed_model_id = config.deployed_model_id
assert (
gca_obj_config.training_dataset
== all_configs[deployed_model_id].skew_detection_config.training_dataset
gca_obj_config.as_proto().training_dataset == expected_training_dataset
)
assert (
gca_obj_config.training_prediction_skew_detection_config
Expand Down
Loading