Skip to content

Commit

Permalink
feat: add MLMD schema class ExperimentModel
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 501468901
  • Loading branch information
jaycee-li authored and copybara-github committed Jan 12, 2023
1 parent 6fa93a4 commit 94b2f29
Show file tree
Hide file tree
Showing 4 changed files with 408 additions and 8 deletions.
24 changes: 23 additions & 1 deletion google/cloud/aiplatform/metadata/schema/base_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def _init_with_resource_name(
self,
*,
artifact_name: str,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):

"""Initializes the Artifact instance using an existing resource.
Expand All @@ -115,13 +119,31 @@ def _init_with_resource_name(
artifact_name (str):
Artifact name with the following format, this is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
metadata_store_id (str):
Optional. MetadataStore to retrieve Artifact from. If not set, metadata_store_id is set to "default".
If artifact_name is a fully-qualified resource, its metadata_store_id overrides this one.
project (str):
Optional. Project to retrieve the artifact from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve the Artifact from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve this Artifact. Overrides
credentials set in aiplatform.init.
"""
# Add User Agent Header for metrics tracking if one is not specified
# If one is already specified this call was initiated by a sub class.
if not base_constants.USER_AGENT_SDK_COMMAND:
base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_artifact.BaseArtifactSchema._init_with_resource_name"

super(BaseArtifactSchema, self).__init__(artifact_name=artifact_name)
super(BaseArtifactSchema, self).__init__(
artifact_name=artifact_name,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)

def create(
self,
Expand Down
157 changes: 156 additions & 1 deletion google/cloud/aiplatform/metadata/schema/google/artifact_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
from typing import Optional, Dict, List

from google.auth import credentials as auth_credentials
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
from google.cloud.aiplatform.metadata.schema import base_artifact
from google.cloud.aiplatform.metadata.schema import utils
Expand Down Expand Up @@ -359,7 +360,6 @@ def __init__(
extended_metadata = copy.deepcopy(metadata) if metadata else {}
if aggregation_type:
if aggregation_type not in _CLASSIFICATION_METRICS_AGGREGATION_TYPE:
## Todo: add negative test case for this
raise ValueError(
"aggregation_type can only be 'AGGREGATION_TYPE_UNSPECIFIED', 'MACRO_AVERAGE', or 'MICRO_AVERAGE'."
)
Expand Down Expand Up @@ -583,3 +583,158 @@ def __init__(
metadata=extended_metadata,
state=state,
)


class ExperimentModel(base_artifact.BaseArtifactSchema):
"""An artifact representing a Vertex Experiment Model."""

schema_title = "google.ExperimentModel"

RESERVED_METADATA_KEYS = [
"frameworkName",
"frameworkVersion",
"modelFile",
"modelClass",
"predictSchemata",
]

def __init__(
self,
*,
framework_name: str,
framework_version: str,
model_file: str,
uri: str,
model_class: Optional[str] = None,
predict_schemata: Optional[utils.PredictSchemata] = None,
artifact_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
framework_name (str):
Required. The name of the model's framework. E.g., 'sklearn'
framework_version (str):
Required. The version of the model's framework. E.g., '1.1.0'
model_file (str):
Required. The file name of the model. E.g., 'model.pkl'
uri (str):
Required. The uniform resource identifier of the model artifact directory.
model_class (str):
Optional. The class name of the model. E.g., 'sklearn.linear_model._base.LinearRegression'
predict_schemata (PredictSchemata):
Optional. An instance of PredictSchemata which holds instance, parameter and prediction schema uris.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the format. This is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
apture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
"""
if metadata:
for k in metadata:
if k in self.RESERVED_METADATA_KEYS:
raise ValueError(f"'{k}' is a system reserved key in metadata.")
extended_metadata = copy.deepcopy(metadata)
else:
extended_metadata = {}
extended_metadata["frameworkName"] = framework_name
extended_metadata["frameworkVersion"] = framework_version
extended_metadata["modelFile"] = model_file
if model_class is not None:
extended_metadata["modelClass"] = model_class
if predict_schemata is not None:
extended_metadata["predictSchemata"] = predict_schemata.to_dict()

super().__init__(
uri=uri,
artifact_id=artifact_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
state=state,
)

@classmethod
def get(
cls,
artifact_id: str,
*,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "ExperimentModel":
"""Retrieves an existing ExperimentModel artifact given an artifact id.
Args:
artifact_id (str):
Required. An artifact id of the ExperimentModel artifact.
metadata_store_id (str):
Optional. MetadataStore to retrieve Artifact from. If not set, metadata_store_id is set to "default".
If artifact_id is a fully-qualified resource name, its metadata_store_id overrides this one.
project (str):
Optional. Project to retrieve the artifact from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve the Artifact from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve this Artifact. Overrides
credentials set in aiplatform.init.
Returns:
An ExperimentModel class that represents an Artifact resource.
Raises:
ValueError: if artifact's schema title is not 'google.ExperimentModel'.
"""
experiment_model = ExperimentModel(
framework_name="",
framework_version="",
model_file="",
uri="",
)
experiment_model._init_with_resource_name(
artifact_name=artifact_id,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
if experiment_model.schema_title != cls.schema_title:
raise ValueError(
f"The schema title of the artifact must be {cls.schema_title}."
f"Got {experiment_model.schema_title}."
)
return experiment_model

@property
def framework_name(self) -> Optional[str]:
return self.metadata.get("frameworkName")

@property
def framework_version(self) -> Optional[str]:
return self.metadata.get("frameworkVersion")

@property
def model_class(self) -> Optional[str]:
return self.metadata.get("modelClass")
48 changes: 42 additions & 6 deletions google/cloud/aiplatform/metadata/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ class PredictSchemata:
prediction_schema_uri: str

def to_dict(self):
"""ML metadata schema dictionary representation of this DataClass"""
"""ML metadata schema dictionary representation of this DataClass.
Returns:
A dictionary that represents the PredictSchemata class.
"""
results = {}
results["instanceSchemaUri"] = self.instance_schema_uri
results["parametersSchemaUri"] = self.parameters_schema_uri
Expand All @@ -62,6 +67,7 @@ def to_dict(self):
@dataclass
class ContainerSpec:
"""Container configuration for the model.
Args:
image_uri (str):
Required. URI of the Docker image to be used as the custom
Expand Down Expand Up @@ -124,7 +130,12 @@ class ContainerSpec:
health_route: Optional[str] = None

def to_dict(self):
"""ML metadata schema dictionary representation of this DataClass"""
"""ML metadata schema dictionary representation of this DataClass.
Returns:
A dictionary that represents the ContainerSpec class.
"""
results = {}
results["imageUri"] = self.image_uri
if self.command:
Expand All @@ -146,6 +157,7 @@ def to_dict(self):
@dataclass
class AnnotationSpec:
"""A class that represents the annotation spec of a Confusion Matrix.
Args:
display_name (str):
Optional. Display name for a column of a confusion matrix.
Expand All @@ -157,7 +169,12 @@ class AnnotationSpec:
id: Optional[str] = None

def to_dict(self):
"""ML metadata schema dictionary representation of this DataClass"""
"""ML metadata schema dictionary representation of this DataClass.
Returns:
A dictionary that represents the AnnotationSpec class.
"""
results = {}
if self.display_name:
results["displayName"] = self.display_name
Expand All @@ -170,6 +187,7 @@ def to_dict(self):
@dataclass
class ConfusionMatrix:
"""A class that represents a Confusion Matrix.
Args:
matrix (List[List[int]]):
Required. A 2D array of integers that represets the values for the confusion matrix.
Expand All @@ -181,10 +199,23 @@ class ConfusionMatrix:
annotation_specs: Optional[List[AnnotationSpec]] = None

def to_dict(self):
## Todo: add a validation to check 'matrix' and 'annotation_specs' have the same length
"""ML metadata schema dictionary representation of this DataClass"""
"""ML metadata schema dictionary representation of this DataClass.
Returns:
A dictionary that represents the ConfusionMatrix class.
Raises:
ValueError: if annotation_specs and matrix have different length.
"""
results = {}
if self.annotation_specs:
if len(self.annotation_specs) != len(self.matrix):
raise ValueError(
"Length of annotation_specs and matrix must be the same. "
"Got lengths {} and {} respectively.".format(
len(self.annotation_specs), len(self.matrix)
)
)
results["annotationSpecs"] = [
annotation_spec.to_dict() for annotation_spec in self.annotation_specs
]
Expand Down Expand Up @@ -255,7 +286,12 @@ class ConfidenceMetric:
confusion_matrix: Optional[ConfusionMatrix] = None

def to_dict(self):
"""ML metadata schema dictionary representation of this DataClass"""
"""ML metadata schema dictionary representation of this DataClass.
Returns:
A dictionary that represents the ConfidenceMetric class.
"""
results = {}
results["confidenceThreshold"] = self.confidence_threshold
if self.recall is not None:
Expand Down
Loading

0 comments on commit 94b2f29

Please sign in to comment.