Skip to content

Commit

Permalink
feat: Add endpoind_id arg to Endpoint#create
Browse files Browse the repository at this point in the history
  • Loading branch information
mai-nakagawa committed Apr 19, 2022
1 parent c1e899d commit a3ebf9d
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
26 changes: 26 additions & 0 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def network(self) -> Optional[str]:
def create(
cls,
display_name: Optional[str] = None,
endpoint_id: Optional[str] = None,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = (),
Expand All @@ -221,6 +222,17 @@ def create(
location (str):
Required. Location to retrieve endpoint from. If not set, location
set in aiplatform.init will be used.
endpoint_id (str):
Optional. The ID to use for endpoint, which will become
the final component of the endpoint resource name. If
not provided, Vertex AI will generate a value for this
ID.
This value should be 1-10 characters, and valid
characters are /[0-9]/. When using HTTP/JSON, this field
is populated based on a query string argument, such as
``?endpoint_id=12345``. This is the fallback for fields
that are not included in either the URI or the body.
description (str):
Optional. The description of the Endpoint.
labels (Dict[str, str]):
Expand Down Expand Up @@ -276,6 +288,7 @@ def create(
return cls._create(
api_client=api_client,
display_name=display_name,
endpoint_id=endpoint_id,
project=project,
location=location,
description=description,
Expand All @@ -297,6 +310,7 @@ def _create(
display_name: str,
project: str,
location: str,
endpoint_id: Optional[str] = None,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = (),
Expand All @@ -321,6 +335,17 @@ def _create(
location (str):
Required. Location to retrieve endpoint from. If not set, location
set in aiplatform.init will be used.
endpoint_id (str):
Optional. The ID to use for endpoint, which will become
the final component of the endpoint resource name. If
not provided, Vertex AI will generate a value for this
ID.
This value should be 1-10 characters, and valid
characters are /[0-9]/. When using HTTP/JSON, this field
is populated based on a query string argument, such as
``?endpoint_id=12345``. This is the fallback for fields
that are not included in either the URI or the body.
description (str):
Optional. The description of the Endpoint.
labels (Dict[str, str]):
Expand Down Expand Up @@ -368,6 +393,7 @@ def _create(
operation_future = api_client.create_endpoint(
parent=parent,
endpoint=gapic_endpoint,
endpoint_id=endpoint_id,
metadata=metadata,
timeout=create_request_timeout,
)
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/aiplatform/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def test_init_aiplatform_with_encryption_key_name_and_create_endpoint(
create_endpoint_mock.assert_called_once_with(
parent=_TEST_PARENT,
endpoint=expected_endpoint,
endpoint_id=None,
metadata=(),
timeout=None,
)
Expand All @@ -573,13 +574,39 @@ def test_create(self, create_endpoint_mock, sync):
create_endpoint_mock.assert_called_once_with(
parent=_TEST_PARENT,
endpoint=expected_endpoint,
endpoint_id=None,
metadata=(),
timeout=None,
)

expected_endpoint.name = _TEST_ENDPOINT_NAME
assert my_endpoint._gca_resource == expected_endpoint

@pytest.mark.usefixtures("get_endpoint_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_create_with_endpoint_id(self, create_endpoint_mock, sync):
my_endpoint = models.Endpoint.create(
display_name=_TEST_DISPLAY_NAME,
endpoint_id=_TEST_ID,
description=_TEST_DESCRIPTION,
sync=sync,
create_request_timeout=None,
)
if not sync:
my_endpoint.wait()

expected_endpoint = gca_endpoint.Endpoint(
display_name=_TEST_DISPLAY_NAME,
description=_TEST_DESCRIPTION,
)
create_endpoint_mock.assert_called_once_with(
parent=_TEST_PARENT,
endpoint=expected_endpoint,
endpoint_id=_TEST_ID,
metadata=(),
timeout=None,
)

@pytest.mark.usefixtures("get_endpoint_mock")
@pytest.mark.parametrize("sync", [True, False])
def test_create_with_timeout(self, create_endpoint_mock, sync):
Expand All @@ -599,6 +626,7 @@ def test_create_with_timeout(self, create_endpoint_mock, sync):
create_endpoint_mock.assert_called_once_with(
parent=_TEST_PARENT,
endpoint=expected_endpoint,
endpoint_id=None,
metadata=(),
timeout=180.0,
)
Expand Down Expand Up @@ -642,6 +670,7 @@ def test_create_with_description(self, create_endpoint_mock, sync):
create_endpoint_mock.assert_called_once_with(
parent=_TEST_PARENT,
endpoint=expected_endpoint,
endpoint_id=None,
metadata=(),
timeout=None,
)
Expand All @@ -665,6 +694,7 @@ def test_create_with_labels(self, create_endpoint_mock, sync):
create_endpoint_mock.assert_called_once_with(
parent=_TEST_PARENT,
endpoint=expected_endpoint,
endpoint_id=None,
metadata=(),
timeout=None,
)
Expand Down

0 comments on commit a3ebf9d

Please sign in to comment.