Skip to content

Commit

Permalink
feat: Add notification_channels field to model monitoring alert config.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 593812234
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Dec 26, 2023
1 parent 9a8e1ca commit bb228ce
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 33 deletions.
10 changes: 8 additions & 2 deletions google/cloud/aiplatform/model_monitoring/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,23 @@
# limitations under the License.
#

from google.cloud.aiplatform.model_monitoring.alert import EmailAlertConfig
from google.cloud.aiplatform.model_monitoring.alert import (
AlertConfig,
EmailAlertConfig,
)
from google.cloud.aiplatform.model_monitoring.objective import (
SkewDetectionConfig,
DriftDetectionConfig,
ExplanationConfig,
ObjectiveConfig,
)
from google.cloud.aiplatform.model_monitoring.sampling import RandomSampleConfig
from google.cloud.aiplatform.model_monitoring.sampling import (
RandomSampleConfig,
)
from google.cloud.aiplatform.model_monitoring.schedule import ScheduleConfig

__all__ = (
"AlertConfig",
"EmailAlertConfig",
"SkewDetectionConfig",
"DriftDetectionConfig",
Expand Down
70 changes: 47 additions & 23 deletions google/cloud/aiplatform/model_monitoring/alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,56 +15,80 @@
# limitations under the License.
#

from typing import Optional, List
from typing import List, Optional
from google.cloud.aiplatform_v1.types import (
model_monitoring as gca_model_monitoring_v1,
)

# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
# TODO(b/242108750): remove temporary logic once model monitoring for
# batch prediction is GA.
from google.cloud.aiplatform_v1beta1.types import (
model_monitoring as gca_model_monitoring_v1beta1,
)

gca_model_monitoring = gca_model_monitoring_v1


class EmailAlertConfig:
class AlertConfig:
def __init__(
self, user_emails: List[str] = [], enable_logging: Optional[bool] = False
self,
user_emails: List[str] = [],
enable_logging: Optional[bool] = False,
notification_channels: List[str] = [],
):
"""Initializer for EmailAlertConfig.
"""Initializer for AlertConfig.
Args:
user_emails (List[str]):
The email addresses to send the alert to.
enable_logging (bool):
Optional. Defaults to False. Streams detected anomalies to Cloud Logging. The anomalies will be
put into json payload encoded from proto
[google.cloud.aiplatform.logging.ModelMonitoringAnomaliesLogEntry][].
This can be further sync'd to Pub/Sub or any other services
supported by Cloud Logging.
user_emails (List[str]): The email addresses to send the alert to.
enable_logging (bool): Optional. Defaults to False. Streams detected
anomalies to Cloud Logging. The anomalies will be put into json
payload encoded from proto
[google.cloud.aiplatform.logging.ModelMonitoringAnomaliesLogEntry][].
This can be further sync'd to Pub/Sub or any other services supported
by Cloud Logging.
notification_channels (List[str]): The Cloud notification channels to
send the alert to.
"""
self.enable_logging = enable_logging
self.user_emails = user_emails
self.enable_logging = enable_logging
self.notification_channels = notification_channels
self._config_for_bp = False

# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
def as_proto(self) -> gca_model_monitoring.ModelMonitoringAlertConfig:
"""Converts EmailAlertConfig to a proto message.
"""Converts AlertConfig to a proto message.
Returns:
The GAPIC representation of the email alert config.
The GAPIC representation of the alert config.
"""
# TODO(b/242108750): remove temporary logic once model monitoring for
# batch prediction is GA.
if self._config_for_bp:
gca_model_monitoring = gca_model_monitoring_v1beta1
else:
gca_model_monitoring = gca_model_monitoring_v1
user_email_alert_config = (
gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig(
user_emails=self.user_emails
)
)

return gca_model_monitoring.ModelMonitoringAlertConfig(
email_alert_config=user_email_alert_config,
email_alert_config=gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig(
user_emails=self.user_emails
),
enable_logging=self.enable_logging,
notification_channels=self.notification_channels,
)


class EmailAlertConfig(AlertConfig):
def __init__(
self, user_emails: List[str] = [], enable_logging: Optional[bool] = False
):
"""Initializer for EmailAlertConfig.
Args:
user_emails (List[str]): The email addresses to send the alert to.
enable_logging (bool): Optional. Defaults to False. Streams detected
anomalies to Cloud Logging. The anomalies will be put into json
payload encoded from proto
[google.cloud.aiplatform.logging.ModelMonitoringAnomaliesLogEntry][].
This can be further sync'd to Pub/Sub or any other services supported
by Cloud Logging.
"""
super().__init__(user_emails=user_emails, enable_logging=enable_logging)
49 changes: 42 additions & 7 deletions tests/system/aiplatform/test_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

# constants used for testing
USER_EMAIL = "[email protected]"
NOTIFICATION_CHANNEL = "projects/123/notificationChannels/456"
PERMANENT_CHURN_MODEL_ID = "5295507484113371136"
CHURN_MODEL_PATH = "gs://mco-mm/churn"
DEFAULT_INPUT = {
Expand Down Expand Up @@ -90,10 +91,16 @@
# global test constants
sampling_strategy = model_monitoring.RandomSampleConfig(sample_rate=LOG_SAMPLE_RATE)

alert_config = model_monitoring.EmailAlertConfig(
email_alert_config = model_monitoring.EmailAlertConfig(
user_emails=[USER_EMAIL], enable_logging=True
)

alert_config = model_monitoring.AlertConfig(
user_emails=[USER_EMAIL],
enable_logging=True,
notification_channels=[NOTIFICATION_CHANNEL],
)

schedule_config = model_monitoring.ScheduleConfig(monitor_interval=MONITOR_INTERVAL)

skew_config = model_monitoring.SkewDetectionConfig(
Expand Down Expand Up @@ -149,7 +156,7 @@ def test_mdm_two_models_one_valid_config(self, shared_state):
display_name=self._make_display_name(key=JOB_NAME),
logging_sampling_strategy=sampling_strategy,
schedule_config=schedule_config,
alert_config=alert_config,
alert_config=email_alert_config,
objective_configs=objective_config,
create_request_timeout=3600,
project=e2e_base._PROJECT,
Expand Down Expand Up @@ -211,7 +218,7 @@ def test_mdm_pause_and_update_config(self, shared_state):
display_name=self._make_display_name(key=JOB_NAME),
logging_sampling_strategy=sampling_strategy,
schedule_config=schedule_config,
alert_config=alert_config,
alert_config=email_alert_config,
objective_configs=model_monitoring.ObjectiveConfig(
drift_detection_config=drift_config
),
Expand Down Expand Up @@ -284,7 +291,7 @@ def test_mdm_two_models_two_valid_configs(self, shared_state):
display_name=self._make_display_name(key=JOB_NAME),
logging_sampling_strategy=sampling_strategy,
schedule_config=schedule_config,
alert_config=alert_config,
alert_config=email_alert_config,
objective_configs=all_configs,
create_request_timeout=3600,
project=e2e_base._PROJECT,
Expand Down Expand Up @@ -338,7 +345,7 @@ def test_mdm_invalid_config_incorrect_model_id(self, shared_state):
display_name=self._make_display_name(key=JOB_NAME),
logging_sampling_strategy=sampling_strategy,
schedule_config=schedule_config,
alert_config=alert_config,
alert_config=email_alert_config,
objective_configs=objective_config,
create_request_timeout=3600,
project=e2e_base._PROJECT,
Expand All @@ -358,7 +365,7 @@ def test_mdm_invalid_config_xai(self, shared_state):
display_name=self._make_display_name(key=JOB_NAME),
logging_sampling_strategy=sampling_strategy,
schedule_config=schedule_config,
alert_config=alert_config,
alert_config=email_alert_config,
objective_configs=objective_config,
create_request_timeout=3600,
project=e2e_base._PROJECT,
Expand Down Expand Up @@ -388,7 +395,7 @@ def test_mdm_two_models_invalid_configs_xai(self, shared_state):
display_name=self._make_display_name(key=JOB_NAME),
logging_sampling_strategy=sampling_strategy,
schedule_config=schedule_config,
alert_config=alert_config,
alert_config=email_alert_config,
objective_configs=all_configs,
create_request_timeout=3600,
project=e2e_base._PROJECT,
Expand All @@ -399,3 +406,31 @@ def test_mdm_two_models_invalid_configs_xai(self, shared_state):
"`explanation_config` should only be enabled if the model has `explanation_spec populated"
in str(e.value)
)

def test_mdm_notification_channel_alert_config(self, shared_state):
self.endpoint = shared_state["resources"][0]
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
# test model monitoring configurations
job = aiplatform.ModelDeploymentMonitoringJob.create(
display_name=self._make_display_name(key=JOB_NAME),
logging_sampling_strategy=sampling_strategy,
schedule_config=schedule_config,
alert_config=alert_config,
objective_configs=objective_config,
create_request_timeout=3600,
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
endpoint=self.endpoint,
)

gapic_job = job._gca_resource
assert (
gapic_job.model_monitoring_alert_config.email_alert_config.user_emails
== [USER_EMAIL]
)
assert gapic_job.model_monitoring_alert_config.enable_logging
assert gapic_job.model_monitoring_alert_config.notification_channels == [
NOTIFICATION_CHANNEL
]

job.delete()
18 changes: 17 additions & 1 deletion tests/unit/aiplatform/test_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
_TEST_DRIFT_TRESHOLD = {"key": 0.2}
_TEST_EMAIL1 = "test1"
_TEST_EMAIL2 = "test2"
_TEST_NOTIFICATION_CHANNEL = "projects/123/notificationChannels/456"
_TEST_VALID_DATA_FORMATS = ["tf-record", "csv", "jsonl"]
_TEST_SAMPLING_RATE = 0.8
_TEST_MONITORING_INTERVAL = 1
Expand Down Expand Up @@ -105,10 +106,16 @@ def test_valid_configs(
monitor_interval=_TEST_MONITORING_INTERVAL
)

alert_config = model_monitoring.EmailAlertConfig(
email_alert_config = model_monitoring.EmailAlertConfig(
user_emails=[_TEST_EMAIL1, _TEST_EMAIL2]
)

alert_config = model_monitoring.AlertConfig(
user_emails=[_TEST_EMAIL1, _TEST_EMAIL2],
enable_logging=True,
notification_channels=[_TEST_NOTIFICATION_CHANNEL],
)

prediction_drift_config = model_monitoring.DriftDetectionConfig(
drift_thresholds=_TEST_DRIFT_TRESHOLD
)
Expand Down Expand Up @@ -149,8 +156,17 @@ def test_valid_configs(
== prediction_drift_config.as_proto()
)
assert objective_config.as_proto().explanation_config == xai_config.as_proto()
assert (
_TEST_EMAIL1 in email_alert_config.as_proto().email_alert_config.user_emails
)
assert (
_TEST_EMAIL2 in email_alert_config.as_proto().email_alert_config.user_emails
)
assert _TEST_EMAIL1 in alert_config.as_proto().email_alert_config.user_emails
assert _TEST_EMAIL2 in alert_config.as_proto().email_alert_config.user_emails
assert (
_TEST_NOTIFICATION_CHANNEL in alert_config.as_proto().notification_channels
)
assert (
random_sample_config.as_proto().random_sample_config.sample_rate
== _TEST_SAMPLING_RATE
Expand Down

0 comments on commit bb228ce

Please sign in to comment.