Skip to content

Commit

Permalink
fix: update Model.list_model_evaluations and get_model_evaluation to …
Browse files Browse the repository at this point in the history
…use the provided version (#1616)

* fix: add model versioning support to Model.get_model_evaluation and list_model_evaluations

* linting fix

* update list_model_evaluations to use instantiated model version

* update get evaluation method
  • Loading branch information
sararob authored Aug 31, 2022
1 parent 484e416 commit 8fb836b
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 6 deletions.
15 changes: 10 additions & 5 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4571,10 +4571,14 @@ def list_model_evaluations(
self,
) -> List["model_evaluation.ModelEvaluation"]:
"""List all Model Evaluation resources associated with this model.
If this Model resource was instantiated with a version, the Model
Evaluation resources for that version will be returned. If no version
was provided when the Model resource was instantiated, Model Evaluation
resources will be returned for the default version.
Example Usage:
my_model = Model(
model_name="projects/123/locations/us-central1/models/456"
model_name="projects/123/locations/us-central1/models/456@1"
)
my_evaluations = my_model.list_model_evaluations()
Expand All @@ -4584,10 +4588,8 @@ def list_model_evaluations(
List of ModelEvaluation resources for the model.
"""

self.wait()

return model_evaluation.ModelEvaluation._list(
parent=self.resource_name,
parent=self.versioned_resource_name,
credentials=self.credentials,
)

Expand All @@ -4597,7 +4599,10 @@ def get_model_evaluation(
) -> Optional[model_evaluation.ModelEvaluation]:
"""Returns a ModelEvaluation resource and instantiates its representation.
If no evaluation_id is passed, it will return the first evaluation associated
with this model.
with this model. If the aiplatform.Model resource was instantiated with a
version, this will return a Model Evaluation from that version. If no version
was specified when instantiating the Model resource, this will return an
Evaluation from the default version.
Example usage:
my_model = Model(
Expand Down
39 changes: 38 additions & 1 deletion tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2357,7 +2357,7 @@ def test_update(self, update_model_mock, get_model_mock):
model=current_model_proto, update_mask=update_mask
)

def test_get_model_evaluation_with_id(
def test_get_model_evaluation_with_evaluation_id(
self,
mock_model_eval_get,
get_model_mock,
Expand All @@ -2371,6 +2371,26 @@ def test_get_model_evaluation_with_id(
name=_TEST_MODEL_EVAL_RESOURCE_NAME, retry=base._DEFAULT_RETRY
)

def test_get_model_evaluation_with_evaluation_and_instantiated_version(
self,
mock_model_eval_get,
get_model_mock,
list_model_evaluations_mock,
):
test_model = models.Model(
model_name=f"{_TEST_MODEL_RESOURCE_NAME}@{_TEST_VERSION_ID}"
)

test_model.get_model_evaluation(evaluation_id=_TEST_ID)

mock_model_eval_get.assert_called_once_with(
name=_TEST_MODEL_EVAL_RESOURCE_NAME, retry=base._DEFAULT_RETRY
)

list_model_evaluations_mock.assert_called_once_with(
request={"parent": test_model.versioned_resource_name}
)

def test_get_model_evaluation_without_id(
self,
mock_model_eval_get,
Expand Down Expand Up @@ -2402,6 +2422,23 @@ def test_list_model_evaluations(

assert len(eval_list) == len(_TEST_MODEL_EVAL_LIST)

def test_list_model_evaluations_with_version(
self,
get_model_mock,
mock_model_eval_get,
list_model_evaluations_mock,
):

test_model = models.Model(
model_name=f"{_TEST_MODEL_RESOURCE_NAME}@{_TEST_VERSION_ID}"
)

test_model.list_model_evaluations()

list_model_evaluations_mock.assert_called_once_with(
request={"parent": test_model.versioned_resource_name}
)

def test_init_with_version_in_resource_name(self, get_model_with_version):
model = models.Model(
model_name=models.ModelRegistry._get_versioned_name(
Expand Down

0 comments on commit 8fb836b

Please sign in to comment.