Skip to content

Commit

Permalink
🦉 Updates from OwlBot post-processor
Browse files Browse the repository at this point in the history
  • Loading branch information
gcf-owl-bot[bot] authored and rosiezou committed May 26, 2022
1 parent 018ff9a commit 21e0027
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 139 deletions.
10 changes: 6 additions & 4 deletions google/cloud/aiplatform/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,10 @@
types.model = types.model_v1beta1
types.model_evaluation = types.model_evaluation_v1beta1
types.model_evaluation_slice = types.model_evaluation_slice_v1beta1
types.model_deployment_monitoring_job = types.model_deployment_monitoring_job_v1beta1,
types.model_monitoring = types.model_monitoring_v1beta1,
types.model_deployment_monitoring_job = (
types.model_deployment_monitoring_job_v1beta1,
)
types.model_monitoring = (types.model_monitoring_v1beta1,)
types.model_service = types.model_service_v1beta1
types.operation = types.operation_v1beta1
types.pipeline_job = types.pipeline_job_v1beta1
Expand Down Expand Up @@ -176,8 +178,8 @@
types.model = types.model_v1
types.model_evaluation = types.model_evaluation_v1
types.model_evaluation_slice = types.model_evaluation_slice_v1
types.model_deployment_monitoring_job = types.model_deployment_monitoring_job_v1,
types.model_monitoring = types.model_monitoring_v1,
types.model_deployment_monitoring_job = (types.model_deployment_monitoring_job_v1,)
types.model_monitoring = (types.model_monitoring_v1,)
types.model_service = types.model_service_v1
types.operation = types.operation_v1
types.pipeline_job = types.pipeline_job_v1
Expand Down
124 changes: 70 additions & 54 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat,
machine_resources as gca_machine_resources_compat,
manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,

study as gca_study_compat,
model_deployment_monitoring_job as gca_model_deployment_monitoring_job_compat,
model_monitoring as gca_model_monitoring_compat,
Expand Down Expand Up @@ -384,7 +383,6 @@ def create(
sync: bool = True,
create_request_timeout: Optional[float] = None,
batch_size: Optional[int] = None,

) -> "BatchPredictionJob":
"""Create a batch prediction job.
Expand Down Expand Up @@ -671,7 +669,7 @@ def create(
gapic_batch_prediction_job.manual_batch_tuning_parameters = (
manual_batch_tuning_parameters
)

# User Labels
gapic_batch_prediction_job.labels = labels

Expand Down Expand Up @@ -1923,7 +1921,7 @@ def trials(self) -> List[gca_study_compat.Trial]:

class ModelDeploymentMonitoringJob(_Job):
"""Vertex AI Model Deployment Monitoring Job.
This class should be used in conjunction with the Endpoint class
in order to configure model monitoring for deployed models.
"""
Expand Down Expand Up @@ -1952,10 +1950,12 @@ def __init__(
credentials=credentials,
)


def _parse_configs(objective_configs: Union[
def _parse_configs(
objective_configs: Union[
model_monitoring.EndpointObjectiveConfig,
Dict[str, model_monitoring.EndpointObjectiveConfig]]):
Dict[str, model_monitoring.EndpointObjectiveConfig],
]
):

all_configs = {}
all_models = []
Expand All @@ -1967,28 +1967,33 @@ def _parse_configs(objective_configs: Union[
all_models.append(model)

## when same objective config is applied to ALL models
if isinstance(objective_configs, model_monitoring.EndpointObjectiveConfig) and deployed_model_ids is None:
if (
isinstance(objective_configs, model_monitoring.EndpointObjectiveConfig)
and deployed_model_ids is None
):
for model in all_models:
all_configs[model] = objective_configs

## when same objective config is applied to SOME models
elif isinstance(objective_configs, model_monitoring.EndpointObjectiveConfig) and isinstance(deployed_model_ids, List):
elif isinstance(
objective_configs, model_monitoring.EndpointObjectiveConfig
) and isinstance(deployed_model_ids, List):
for model in deployed_model_ids:
assert(model in all_models)
assert model in all_models
all_configs[model] = objective_configs

## when different objective configs are applied to EACH model
elif isinstance(objective_configs, Dict) and deployed_model_ids is None:
assert(all(model in all_models for model in objective_configs.keys()))
assert all(model in all_models for model in objective_configs.keys())
all_configs = objective_configs

mdm_objective_config_seq = []
for key in all_configs.keys():
mdm_objective_config_seq.append(
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(
deployed_model_id = key,
objective_config = all_configs[key]
))
deployed_model_id=key, objective_config=all_configs[key]
)
)
return mdm_objective_config_seq

@classmethod
Expand All @@ -1998,7 +2003,8 @@ def create(
endpoint: Union[str, "models.Endpoint"],
objective_configs: Union[
model_monitoring.EndpointObjectiveConfig,
Dict[str, model_monitoring.EndpointObjectiveConfig]],
Dict[str, model_monitoring.EndpointObjectiveConfig],
],
logging_sampling_strategy: model_monitoring.RandomSampleConfig,
monitor_interval: int,
schedule_config: model_monitoring.ScheduleConfig,
Expand Down Expand Up @@ -2057,7 +2063,7 @@ def create(
schedule_config (model_monitoring.schedule.ScheduleConfig):
Configures model monitoring job scheduling interval in hours.
This defines how often the monitoring jobs are triggered.
alert_config (model_monitoring.alert.EmailAlertConfig):
Optional. Configures how alerts are sent to the user. Right now
only email alert is supported.
Expand Down Expand Up @@ -2111,7 +2117,7 @@ def create(
underscores and dashes. International characters
are allowed. See https://goo.gl/xmQnxf for more information
and examples of labels.
encryption_spec_key_name (str):
Optional. Customer-managed encryption key spec for a
ModelDeploymentMonitoringJob. If set, this
Expand Down Expand Up @@ -2142,49 +2148,53 @@ def create(

if stats_anomalies_base_directory:
stats_anomalies_base_directory = gca_io_compat.GcsDestination(
output_uri_prefix = stats_anomalies_base_directory)

output_uri_prefix=stats_anomalies_base_directory
)

if encryption_spec_key_name:
encryption_spec_key_name = gca_encryption_spec_compat.EncryptionSpec(
kms_key_name = encryption_spec_key_name)
kms_key_name=encryption_spec_key_name
)

mdm_objective_config_seq = cls._parse_configs(objective_configs)

self._gca_resource = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
name=self.model_deployment_monitoring_job_name,
display_name=display_name,
endpoint=endpoint,
model_deployment_monitoring_objective_configs=mdm_objective_config_seq,
logging_sampling_strategy=logging_sampling_strategy,
model_deployment_monitoring_schedule_config=schedule_config,
model_monitoring_alert_config=alerting_config,
predict_instance_schema_uri=predict_instance_schema_uri,
analysis_instance_schema_uri=analysis_instance_schema_uri,
sample_predict_instance = sample_predict_instance,
stats_anomalies_base_directory = stats_anomalies_base_directory,
enable_monitoring_pipeline_logs = enable_monitoring_pipeline_logs,
labels = labels,
encryption_spec = encryption_spec_key_name
self._gca_resource = (
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
name=self.model_deployment_monitoring_job_name,
display_name=display_name,
endpoint=endpoint,
model_deployment_monitoring_objective_configs=mdm_objective_config_seq,
logging_sampling_strategy=logging_sampling_strategy,
model_deployment_monitoring_schedule_config=schedule_config,
model_monitoring_alert_config=alerting_config,
predict_instance_schema_uri=predict_instance_schema_uri,
analysis_instance_schema_uri=analysis_instance_schema_uri,
sample_predict_instance=sample_predict_instance,
stats_anomalies_base_directory=stats_anomalies_base_directory,
enable_monitoring_pipeline_logs=enable_monitoring_pipeline_logs,
labels=labels,
encryption_spec=encryption_spec_key_name,
)
)

api_client = cls.api_client
mdm_job = api_client.create_model_deployment_monitoring_job(
parent = parent,
model_deployment_monitoring_job = cls._gca_resource
parent=parent, model_deployment_monitoring_job=cls._gca_resource
)
return mdm_job


def update(
self,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
display_name: Optional[str] = None,
objective_configs: Optional[Union[
model_monitoring.EndpointObjectiveConfig,
Dict[str, model_monitoring.EndpointObjectiveConfig]]] = None,
logging_sampling_strategy: \
Optional[model_monitoring.RandomSampleConfig] = None,
objective_configs: Optional[
Union[
model_monitoring.EndpointObjectiveConfig,
Dict[str, model_monitoring.EndpointObjectiveConfig],
]
] = None,
logging_sampling_strategy: Optional[model_monitoring.RandomSampleConfig] = None,
monitor_interval: Optional[int] = None,
deployed_model_ids: Optional[List[str]] = None,
schedule_config: Optional[model_monitoring.ScheduleConfig] = None,
Expand All @@ -2200,11 +2210,12 @@ def update(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
sync: bool = True,
sync: bool = True,
) -> "ModelDeploymentMonitoringJob":
""""""
current_job = self.api_client.get_model_deployment_monitoring_job(
name = slef.model_deployment_monitoring_job_name)
name=slef.model_deployment_monitoring_job_name
)
update_mask: List[str] = []
if display_name:
update_mask.append("display_name")
Expand All @@ -2226,27 +2237,32 @@ def update(
current_job.log_ttl = bigquery_tables_log_ttl
if enable_monitoring_pipeline_logs:
update_mask.append("enable_monitoring_pipeline_logs")
current_job.enable_monitoring_pipeline_logs = enable_monitoring_pipeline_logs
current_job.enable_monitoring_pipeline_logs = (
enable_monitoring_pipeline_logs
)
if objective_configs:
update_mask.append("model_deployment_monitoring_objective_configs")
current_job.model_deployment_monitoring_objective_configs = self._parse_configs(objective_configs)
current_job.model_deployment_monitoring_objective_configs = (
self._parse_configs(objective_configs)
)
self.api_client.update_model_deployment_monitoring_job(
model_deployment_monitoring_job = current_job,
update_mask = update_mask
model_deployment_monitoring_job=current_job, update_mask=update_mask
)


def pause(self) -> "ModelDeploymentMonitoringJob":
""""""
self.api_client.pause_model_deployment_monitoring_job(
self.model_deployment_monitoring_job_name)
self.model_deployment_monitoring_job_name
)

def resume(self) -> "ModelDeploymentMonitoringJob":
""""""
self.api_client.resume_model_deployment_monitoring_job(
self.model_deployment_monitoring_job_name)
self.model_deployment_monitoring_job_name
)

def delete(self) -> "ModelDeploymentMonitoringJob":
""""""
self.api_client.delete_model_deployment_monitoring_job(
self.model_deployment_monitoring_job_name)
self.model_deployment_monitoring_job_name
)
9 changes: 7 additions & 2 deletions google/cloud/aiplatform/model_monitoring/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
#

from google.cloud.aiplatform.model_monitoring.alert import EmailAlertConfig
from google.cloud.aiplatform.model_monitoring.objective import EndpointSkewDetectionConfig, EndpointDriftDetectionConfig, EndpointExplanationConfig, EndpointObjectiveConfig
from google.cloud.aiplatform.model_monitoring.objective import (
EndpointSkewDetectionConfig,
EndpointDriftDetectionConfig,
EndpointExplanationConfig,
EndpointObjectiveConfig,
)
from google.cloud.aiplatform.model_monitoring.sampling import RandomSampleConfig
from google.cloud.aiplatform.model_monitoring.schedule import ScheduleConfig

Expand All @@ -27,5 +32,5 @@
"EndpointExplanationConfig",
"EndpointObjectiveConfig",
"RandomSampleConfig",
"ScheduleConfig"
"ScheduleConfig",
)
35 changes: 16 additions & 19 deletions google/cloud/aiplatform/model_monitoring/alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,20 @@

import abc
from typing import Optional, List
from google.cloud.aiplatform.compat.types import model_monitoring as gca_model_monitoring
from google.cloud.aiplatform.compat.types import (
model_monitoring as gca_model_monitoring,
)


class _AlertConfig(abc.ABC):
"""An abstract class for setting model monitoring alert config"""
def __init__(
self,
enable_logging: Optional[bool] = None
):

def __init__(self, enable_logging: Optional[bool] = None):
self.enable_logging = enable_logging


class EmailAlertConfig(_AlertConfig):
def __init__(
self,
user_emails: List[str],
enable_logging: Optional[bool] = None
):
def __init__(self, user_emails: List[str], enable_logging: Optional[bool] = None):
"""Initializer for EmailAlertConfig
Args:
Expand All @@ -49,16 +46,16 @@ def __init__(
Returns:
An instance of EmailAlertConfig
"""
super().__init__(
enable_logging = self.enable_logging
)
super().__init__(enable_logging=self.enable_logging)
self.user_emails = user_emails

def as_proto(self):
user_email_alert_config = gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig(
user_emails = self.user_emails
user_email_alert_config = (
gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig(
user_emails=self.user_emails
)
)
return gca_model_monitoring.ModelMonitoringAlertConfig(
email_alert_config = user_email_alert_config,
enable_logging = self.enable_logging
email_alert_config=user_email_alert_config,
enable_logging=self.enable_logging,
)
Loading

0 comments on commit 21e0027

Please sign in to comment.