Skip to content

Commit

Permalink
feat: Lazy load Endpoint class (#655)
Browse files Browse the repository at this point in the history
* feat: Lazy load Endpoint class

* Address reviewer comments

* Instance attr, add to E2E test

* Replace _skipped_getter_call instance attr with method

* Update tests, handle sync=False case

* 🦉 Updates from OwlBot

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* Add explicit project to GCS client in E2E test

* Simplifying implementation further

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
vinnysenthil and gcf-owl-bot[bot] authored Oct 7, 2021
1 parent 7cb6976 commit c795c6f
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 45 deletions.
43 changes: 39 additions & 4 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,43 @@ def __init__(
credentials=credentials,
resource_name=endpoint_name,
)
self._gca_resource = self._get_gca_resource(resource_name=endpoint_name)

endpoint_name = utils.full_resource_name(
resource_name=endpoint_name,
resource_noun="endpoints",
project=project,
location=location,
)

# Lazy load the Endpoint gca_resource until needed
self._gca_resource = gca_endpoint_compat.Endpoint(name=endpoint_name)

self._prediction_client = self._instantiate_prediction_client(
location=self.location, credentials=credentials,
)

def _skipped_getter_call(self) -> bool:
"""Check if GAPIC resource was populated by call to get/list API methods
Returns False if `_gca_resource` is None or fully populated. Returns True
if `_gca_resource` is partially populated
"""
return self._gca_resource and not self._gca_resource.create_time

def _sync_gca_resource_if_skipped(self) -> None:
"""Sync GAPIC service representation of Endpoint class resource only if
get_endpoint() was never called."""
if self._skipped_getter_call():
self._gca_resource = self._get_gca_resource(
resource_name=self._gca_resource.name
)

def _assert_gca_resource_is_available(self) -> None:
"""Ensures Endpoint getter was called at least once before
asserting on gca_resource's availability."""
super()._assert_gca_resource_is_available()
self._sync_gca_resource_if_skipped()

@property
def traffic_split(self) -> Dict[str, int]:
"""A map from a DeployedModel's ID to the percentage of this Endpoint's
Expand Down Expand Up @@ -315,8 +346,8 @@ def _create(

_LOGGER.log_create_complete(cls, created_endpoint, "endpoint")

return cls(
endpoint_name=created_endpoint.name,
return cls._construct_sdk_resource_from_gapic(
gapic_resource=created_endpoint,
project=project,
location=location,
credentials=credentials,
Expand Down Expand Up @@ -622,6 +653,7 @@ def deploy(
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
"""
self._sync_gca_resource_if_skipped()

self._validate_deploy_args(
min_replica_count,
Expand Down Expand Up @@ -967,6 +999,8 @@ def undeploy(
Optional. Strings which should be sent along with the request as
metadata.
"""
self._sync_gca_resource_if_skipped()

if traffic_split is not None:
if deployed_model_id in traffic_split and traffic_split[deployed_model_id]:
raise ValueError("Model being undeployed should have 0 traffic.")
Expand Down Expand Up @@ -1011,6 +1045,7 @@ def _undeploy(
Optional. Strings which should be sent along with the request as
metadata.
"""
self._sync_gca_resource_if_skipped()
current_traffic_split = traffic_split or dict(self._gca_resource.traffic_split)

if deployed_model_id in current_traffic_split:
Expand Down Expand Up @@ -1095,7 +1130,7 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
self.wait()

prediction_response = self._prediction_client.predict(
endpoint=self.resource_name, instances=instances, parameters=parameters
endpoint=self._gca_resource.name, instances=instances, parameters=parameters
)

return Prediction(
Expand Down
43 changes: 17 additions & 26 deletions tests/system/aiplatform/test_e2e_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@
_LOCAL_TRAINING_SCRIPT_PATH = os.path.join(
_DIR_NAME, "test_resources/california_housing_training_script.py"
)
_INSTANCE = {
"longitude": -124.35,
"latitude": 40.54,
"housing_median_age": 52.0,
"total_rooms": 1820.0,
"total_bedrooms": 300.0,
"population": 806,
"households": 270.0,
"median_income": 3.014700,
}


@pytest.mark.usefixtures("prepare_staging_bucket", "delete_staging_bucket", "teardown")
Expand Down Expand Up @@ -136,39 +146,20 @@ def test_end_to_end_tabular(self, shared_state):
# Send online prediction with same instance to both deployed models
# This sample is taken from an observation where median_house_value = 94600
custom_endpoint.wait()
custom_prediction = custom_endpoint.predict(
[
{
"longitude": -124.35,
"latitude": 40.54,
"housing_median_age": 52.0,
"total_rooms": 1820.0,
"total_bedrooms": 300.0,
"population": 806,
"households": 270.0,
"median_income": 3.014700,
},
]
)
custom_prediction = custom_endpoint.predict([_INSTANCE])

custom_batch_prediction_job.wait()

automl_endpoint.wait()
automl_prediction = automl_endpoint.predict(
[
{
"longitude": "-124.35",
"latitude": "40.54",
"housing_median_age": "52.0",
"total_rooms": "1820.0",
"total_bedrooms": "300.0",
"population": "806",
"households": "270.0",
"median_income": "3.014700",
},
]
[{k: str(v) for k, v in _INSTANCE.items()}] # Cast int values to strings
)

# Test lazy loading of Endpoint, check getter was never called after predict()
custom_endpoint = aiplatform.Endpoint(custom_endpoint.resource_name)
custom_endpoint.predict([_INSTANCE])
assert custom_endpoint._skipped_getter_call()

assert (
custom_job.state
== gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/aiplatform/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ def test_dataset_create_to_model_predict(
assert endpoint_deploy_return is None

if not sync:
# Accessing attribute in Endpoint that has not been created raises informatively
with pytest.raises(
RuntimeError, match=r"Endpoint resource has not been created."
):
my_endpoint.network

my_endpoint.wait()
created_endpoint.wait()

Expand Down
67 changes: 52 additions & 15 deletions tests/unit/aiplatform/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,15 @@ def get_endpoint_mock():
yield get_endpoint_mock


@pytest.fixture
def get_empty_endpoint_mock():
with mock.patch.object(
endpoint_service_client.EndpointServiceClient, "get_endpoint"
) as get_endpoint_mock:
get_endpoint_mock.return_value = gca_endpoint.Endpoint(name=_TEST_ENDPOINT_NAME)
yield get_endpoint_mock


@pytest.fixture
def get_endpoint_alt_location_mock():
with mock.patch.object(
Expand Down Expand Up @@ -213,7 +222,9 @@ def create_endpoint_mock():
) as create_endpoint_mock:
create_endpoint_lro_mock = mock.Mock(ga_operation.Operation)
create_endpoint_lro_mock.result.return_value = gca_endpoint.Endpoint(
name=_TEST_ENDPOINT_NAME, display_name=_TEST_DISPLAY_NAME
name=_TEST_ENDPOINT_NAME,
display_name=_TEST_DISPLAY_NAME,
encryption_spec=_TEST_ENCRYPTION_SPEC,
)
create_endpoint_mock.return_value = create_endpoint_lro_mock
yield create_endpoint_mock
Expand Down Expand Up @@ -378,19 +389,35 @@ def test_constructor(self, create_endpoint_client_mock):
]
)

def test_constructor_with_endpoint_id(self, get_endpoint_mock):
models.Endpoint(_TEST_ID)
get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME)
def test_lazy_constructor_with_endpoint_id(self, get_endpoint_mock):
ep = models.Endpoint(_TEST_ID)
assert ep._gca_resource.name == _TEST_ENDPOINT_NAME
assert ep._skipped_getter_call()
assert not get_endpoint_mock.called

def test_constructor_with_endpoint_name(self, get_endpoint_mock):
models.Endpoint(_TEST_ENDPOINT_NAME)
def test_lazy_constructor_with_endpoint_name(self, get_endpoint_mock):
ep = models.Endpoint(_TEST_ENDPOINT_NAME)
assert ep._gca_resource.name == _TEST_ENDPOINT_NAME
assert ep._skipped_getter_call()
assert not get_endpoint_mock.called

def test_lazy_constructor_calls_get_on_property_access(self, get_endpoint_mock):
ep = models.Endpoint(_TEST_ENDPOINT_NAME)
assert ep._gca_resource.name == _TEST_ENDPOINT_NAME
assert ep._skipped_getter_call()
assert not get_endpoint_mock.called

ep.display_name # Retrieve a property that requires a call to Endpoint getter
get_endpoint_mock.assert_called_with(name=_TEST_ENDPOINT_NAME)

def test_constructor_with_custom_project(self, get_endpoint_mock):
models.Endpoint(endpoint_name=_TEST_ID, project=_TEST_PROJECT_2)
def test_lazy_constructor_with_custom_project(self, get_endpoint_mock):
ep = models.Endpoint(endpoint_name=_TEST_ID, project=_TEST_PROJECT_2)
test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path(
_TEST_PROJECT_2, _TEST_LOCATION, _TEST_ID
)
assert not get_endpoint_mock.called

ep.name # Retrieve a property that requires a call to Endpoint getter
get_endpoint_mock.assert_called_with(name=test_endpoint_resource_name)

@pytest.mark.usefixtures("get_endpoint_mock")
Expand All @@ -406,11 +433,19 @@ def test_constructor_with_conflicting_location(self):
regexp=r"is provided, but different from the resource location"
)

def test_constructor_with_custom_location(self, get_endpoint_alt_location_mock):
models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2)
def test_lazy_constructor_with_custom_location(
self, get_endpoint_alt_location_mock
):
ep = models.Endpoint(endpoint_name=_TEST_ID, location=_TEST_LOCATION_2)
test_endpoint_resource_name = endpoint_service_client.EndpointServiceClient.endpoint_path(
_TEST_PROJECT, _TEST_LOCATION_2, _TEST_ID
)

# Get Endpoint not called due to lazy loading
assert not get_endpoint_alt_location_mock.called

ep.network # Accessing a property that requires calling getter

get_endpoint_alt_location_mock.assert_called_with(
name=test_endpoint_resource_name
)
Expand Down Expand Up @@ -481,15 +516,17 @@ def test_create(self, create_endpoint_mock, sync):
)

expected_endpoint.name = _TEST_ENDPOINT_NAME
assert my_endpoint.gca_resource == expected_endpoint
assert my_endpoint.network is None
assert my_endpoint._gca_resource == expected_endpoint

@pytest.mark.usefixtures("get_endpoint_mock")
@pytest.mark.usefixtures("get_empty_endpoint_mock")
def test_accessing_properties_with_no_resource_raises(self,):
"""Ensure a descriptive RuntimeError is raised when the
GAPIC object has not been populated"""

my_endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME)

my_endpoint._gca_resource = None
# Create a gca_resource without `name` being populated
my_endpoint._gca_resource = gca_endpoint.Endpoint(create_time=datetime.now())

with pytest.raises(RuntimeError) as e:
my_endpoint.gca_resource
Expand Down Expand Up @@ -909,7 +946,7 @@ def test_undeploy(self, undeploy_model_mock, sync):
traffic_split={"model1": 100},
)
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
assert dict(test_endpoint._gca_resource.traffic_split) == {"model1": 100}
assert dict(test_endpoint.traffic_split) == {"model1": 100}
test_endpoint.undeploy("model1", sync=sync)
if not sync:
test_endpoint.wait()
Expand Down

0 comments on commit c795c6f

Please sign in to comment.