diff --git a/README.md b/README.md index 30b7c75a..c5116430 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,49 @@ we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/) to add `DIGITALOCEAN_ACCESS_TOKEN="My Access Token"`, `GRADIENT_MODEL_ACCESS_KEY="My Model Access Key"` to your `.env` file so that your keys are not stored in source control. +## Knowledge Base Database Polling + +When creating a Knowledge Base, the database deployment can take several minutes. The `wait_for_database()` helper function simplifies polling for the database status: + +```python +from gradient import Gradient +from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError +from gradient._exceptions import APITimeoutError + +client = Gradient() + +# Create a knowledge base +kb_response = client.knowledge_bases.create( + name="My Knowledge Base", + region="nyc1", + embedding_model_uuid="your-embedding-model-uuid", +) + +kb_uuid = kb_response.knowledge_base.uuid + +try: + # Wait for the database to be ready (default: 10 minute timeout, 5 second poll interval) + result = client.knowledge_bases.wait_for_database(kb_uuid) + print(f"Database status: {result.database_status}") # "ONLINE" + + # Custom timeout and poll interval + result = client.knowledge_bases.wait_for_database( + kb_uuid, + timeout=900.0, # 15 minutes + poll_interval=10.0 # Check every 10 seconds + ) + +except KnowledgeBaseDatabaseError as e: + # Database entered a failed state (DECOMMISSIONED or UNHEALTHY) + print(f"Database failed: {e}") + +except APITimeoutError: + # Database did not become ready within the timeout period + print("Timeout: Database did not become ready in time") +``` + +The helper handles all state transitions and will raise appropriate exceptions for failed states or timeouts. 
+ ## Async usage Simply import `AsyncGradient` instead of `Gradient` and use `await` with each API call: diff --git a/pyproject.toml b/pyproject.toml index dade45c8..13bc2865 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -246,5 +246,5 @@ known-first-party = ["gradient", "tests"] [tool.ruff.lint.per-file-ignores] "bin/**.py" = ["T201", "T203"] "scripts/**.py" = ["T201", "T203"] -"tests/**.py" = ["T201", "T203"] +"tests/**.py" = ["T201", "T203", "ARG001"] "examples/**.py" = ["T201", "T203"] diff --git a/src/gradient/resources/knowledge_bases/__init__.py b/src/gradient/resources/knowledge_bases/__init__.py index 80d04328..90ebea00 100644 --- a/src/gradient/resources/knowledge_bases/__init__.py +++ b/src/gradient/resources/knowledge_bases/__init__.py @@ -19,6 +19,7 @@ from .knowledge_bases import ( KnowledgeBasesResource, AsyncKnowledgeBasesResource, + KnowledgeBaseDatabaseError, KnowledgeBasesResourceWithRawResponse, AsyncKnowledgeBasesResourceWithRawResponse, KnowledgeBasesResourceWithStreamingResponse, @@ -40,6 +41,7 @@ "AsyncIndexingJobsResourceWithStreamingResponse", "KnowledgeBasesResource", "AsyncKnowledgeBasesResource", + "KnowledgeBaseDatabaseError", "KnowledgeBasesResourceWithRawResponse", "AsyncKnowledgeBasesResourceWithRawResponse", "KnowledgeBasesResourceWithStreamingResponse", diff --git a/src/gradient/resources/knowledge_bases/knowledge_bases.py b/src/gradient/resources/knowledge_bases/knowledge_bases.py index 00fa0659..d92622a8 100644 --- a/src/gradient/resources/knowledge_bases/knowledge_bases.py +++ b/src/gradient/resources/knowledge_bases/knowledge_bases.py @@ -2,6 +2,8 @@ from __future__ import annotations +import time +import asyncio from typing import Iterable import httpx @@ -25,6 +27,7 @@ DataSourcesResourceWithStreamingResponse, AsyncDataSourcesResourceWithStreamingResponse, ) +from ..._exceptions import APITimeoutError from .indexing_jobs import ( IndexingJobsResource, AsyncIndexingJobsResource, @@ -40,7 +43,13 @@ from 
class KnowledgeBaseDatabaseError(Exception):
    """Raised when a knowledge base database enters a failed state."""


def wait_for_database(
    self,
    uuid: str,
    *,
    timeout: float = 600.0,
    poll_interval: float = 5.0,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
) -> KnowledgeBaseRetrieveResponse:
    """
    Poll the knowledge base until its database is ONLINE or has failed.

    Repeatedly calls ``self.retrieve(uuid, ...)`` and inspects the
    ``database_status`` attribute of each response. Any status other than
    ONLINE, DECOMMISSIONED or UNHEALTHY is treated as "still in progress"
    and triggers another poll after a short sleep.

    Args:
      uuid: The knowledge base UUID to poll

      timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes).
          Measured on a monotonic clock, so system clock adjustments do not
          shorten or extend the wait.

      poll_interval: Time to wait between polls in seconds (default: 5 seconds).
          The final sleep is shortened so that it never runs past the deadline.

      extra_headers: Send extra headers

      extra_query: Add additional query parameters to the request

      extra_body: Add additional JSON properties to the request

    Returns:
      The KnowledgeBaseRetrieveResponse from the poll that reported ONLINE

    Raises:
      ValueError: If `uuid` is empty

      KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY)

      APITimeoutError: If the timeout is exceeded before the database becomes ONLINE
    """
    if not uuid:
        raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}")

    failed_states = {"DECOMMISSIONED", "UNHEALTHY"}

    # time.monotonic() cannot jump backwards or forwards when the wall clock
    # is adjusted (NTP, DST, manual change), unlike time.time().
    deadline = time.monotonic() + timeout

    while True:
        if time.monotonic() >= deadline:
            # The request attached to the error is synthetic: no HTTP call timed
            # out, the overall polling budget did.
            # NOTE(review): the URL is hard-coded to the public API host and will
            # not reflect a client configured with a custom base URL - confirm
            # whether callers rely on `error.request.url`.
            raise APITimeoutError(
                request=httpx.Request(
                    method="GET",
                    url=f"https://api.digitalocean.com/v2/gen-ai/knowledge_bases/{uuid}",
                )
            )

        # NOTE(review): this reads `database_status` directly off whatever
        # retrieve() returns; confirm that is still valid when this helper is
        # invoked through the raw/streaming response wrappers.
        response = self.retrieve(
            uuid,
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
        )

        status = response.database_status

        if status == "ONLINE":
            return response

        if status in failed_states:
            raise KnowledgeBaseDatabaseError(f"Knowledge base database entered failed state: {status}")

        # Re-read the clock *after* the request: the time spent inside
        # retrieve() counts against the budget, so the sleep must be capped by
        # what is actually left rather than by a pre-request measurement.
        sleep_time = min(poll_interval, deadline - time.monotonic())
        if sleep_time > 0:
            time.sleep(sleep_time)
async def wait_for_database(
    self,
    uuid: str,
    *,
    timeout: float = 600.0,
    poll_interval: float = 5.0,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
) -> KnowledgeBaseRetrieveResponse:
    """
    Asynchronously poll the knowledge base until its database is ONLINE or has failed.

    Repeatedly awaits ``self.retrieve(uuid, ...)`` and inspects the
    ``database_status`` attribute of each response. Any status other than
    ONLINE, DECOMMISSIONED or UNHEALTHY is treated as "still in progress"
    and triggers another poll after a non-blocking sleep.

    Args:
      uuid: The knowledge base UUID to poll

      timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes).
          Measured on a monotonic clock, so system clock adjustments do not
          shorten or extend the wait.

      poll_interval: Time to wait between polls in seconds (default: 5 seconds).
          The final sleep is shortened so that it never runs past the deadline.

      extra_headers: Send extra headers

      extra_query: Add additional query parameters to the request

      extra_body: Add additional JSON properties to the request

    Returns:
      The KnowledgeBaseRetrieveResponse from the poll that reported ONLINE

    Raises:
      ValueError: If `uuid` is empty

      KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY)

      APITimeoutError: If the timeout is exceeded before the database becomes ONLINE
    """
    if not uuid:
        raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}")

    failed_states = {"DECOMMISSIONED", "UNHEALTHY"}

    # time.monotonic() cannot jump when the wall clock is adjusted, unlike
    # time.time(), so the timeout is honoured regardless of clock changes.
    deadline = time.monotonic() + timeout

    while True:
        if time.monotonic() >= deadline:
            # The request attached to the error is synthetic: no HTTP call timed
            # out, the overall polling budget did.
            # NOTE(review): the URL is hard-coded to the public API host and will
            # not reflect a client configured with a custom base URL - confirm
            # whether callers rely on `error.request.url`.
            raise APITimeoutError(
                request=httpx.Request(
                    method="GET",
                    url=f"https://api.digitalocean.com/v2/gen-ai/knowledge_bases/{uuid}",
                )
            )

        # NOTE(review): this reads `database_status` directly off whatever
        # retrieve() returns; confirm that is still valid when this helper is
        # invoked through the raw/streaming response wrappers.
        response = await self.retrieve(
            uuid,
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
        )

        status = response.database_status

        if status == "ONLINE":
            return response

        if status in failed_states:
            raise KnowledgeBaseDatabaseError(f"Knowledge base database entered failed state: {status}")

        # Re-read the clock *after* the request: the time spent awaiting
        # retrieve() counts against the budget, so the sleep must be capped by
        # what is actually left rather than by a pre-request measurement.
        sleep_time = min(poll_interval, deadline - time.monotonic())
        if sleep_time > 0:
            await asyncio.sleep(sleep_time)
client: Gradient) -> None: "", ) + @parametrize + def test_method_wait_for_database_success(self, client: Gradient) -> None: + """Test wait_for_database with successful database status transition.""" + from unittest.mock import Mock + + call_count = [0] + + def mock_retrieve(uuid, **kwargs): + call_count[0] += 1 + response = Mock() + # Simulate CREATING -> ONLINE transition + response.database_status = "CREATING" if call_count[0] == 1 else "ONLINE" + return response + + client.knowledge_bases.retrieve = mock_retrieve + + result = client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + assert result.database_status == "ONLINE" + assert call_count[0] == 2 + + @parametrize + def test_method_wait_for_database_failed_state(self, client: Gradient) -> None: + """Test wait_for_database with failed database status.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError + + def mock_retrieve(uuid, **kwargs): + response = Mock() + response.database_status = "UNHEALTHY" + return response + + client.knowledge_bases.retrieve = mock_retrieve + + with pytest.raises(KnowledgeBaseDatabaseError, match="UNHEALTHY"): + client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + @parametrize + def test_method_wait_for_database_timeout(self, client: Gradient) -> None: + """Test wait_for_database with timeout.""" + from unittest.mock import Mock + + from gradient._exceptions import APITimeoutError + + def mock_retrieve(uuid, **kwargs): + response = Mock() + response.database_status = "CREATING" + return response + + client.knowledge_bases.retrieve = mock_retrieve + + with pytest.raises(APITimeoutError): + client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=0.3, + poll_interval=0.1, + ) + + @parametrize + def test_method_wait_for_database_decommissioned(self, client: Gradient) -> None: + """Test wait_for_database with 
DECOMMISSIONED status.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError + + def mock_retrieve(uuid, **kwargs): + response = Mock() + response.database_status = "DECOMMISSIONED" + return response + + client.knowledge_bases.retrieve = mock_retrieve + + with pytest.raises(KnowledgeBaseDatabaseError, match="DECOMMISSIONED"): + client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + @parametrize + def test_path_params_wait_for_database(self, client: Gradient) -> None: + """Test wait_for_database validates uuid parameter.""" + with pytest.raises(ValueError, match=r"Expected a non-empty value for `uuid` but received ''"): + client.knowledge_bases.wait_for_database( + "", + ) + class TestAsyncKnowledgeBases: parametrize = pytest.mark.parametrize( @@ -532,3 +628,99 @@ async def test_path_params_delete(self, async_client: AsyncGradient) -> None: await async_client.knowledge_bases.with_raw_response.delete( "", ) + + @parametrize + async def test_method_wait_for_database_success(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database with successful database status transition.""" + from unittest.mock import Mock + + call_count = [0] + + async def mock_retrieve(uuid, **kwargs): + call_count[0] += 1 + response = Mock() + # Simulate CREATING -> ONLINE transition + response.database_status = "CREATING" if call_count[0] == 1 else "ONLINE" + return response + + async_client.knowledge_bases.retrieve = mock_retrieve + + result = await async_client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + assert result.database_status == "ONLINE" + assert call_count[0] == 2 + + @parametrize + async def test_method_wait_for_database_failed_state(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database with failed database status.""" + from unittest.mock import Mock + + from 
gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError + + async def mock_retrieve(uuid, **kwargs): + response = Mock() + response.database_status = "UNHEALTHY" + return response + + async_client.knowledge_bases.retrieve = mock_retrieve + + with pytest.raises(KnowledgeBaseDatabaseError, match="UNHEALTHY"): + await async_client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + @parametrize + async def test_method_wait_for_database_timeout(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database with timeout.""" + from unittest.mock import Mock + + from gradient._exceptions import APITimeoutError + + async def mock_retrieve(uuid, **kwargs): + response = Mock() + response.database_status = "CREATING" + return response + + async_client.knowledge_bases.retrieve = mock_retrieve + + with pytest.raises(APITimeoutError): + await async_client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=0.3, + poll_interval=0.1, + ) + + @parametrize + async def test_method_wait_for_database_decommissioned(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database with DECOMMISSIONED status.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError + + async def mock_retrieve(uuid, **kwargs): + response = Mock() + response.database_status = "DECOMMISSIONED" + return response + + async_client.knowledge_bases.retrieve = mock_retrieve + + with pytest.raises(KnowledgeBaseDatabaseError, match="DECOMMISSIONED"): + await async_client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + @parametrize + async def test_path_params_wait_for_database(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database validates uuid parameter.""" + with pytest.raises(ValueError, match=r"Expected a non-empty value for `uuid` but received ''"): + await 
async_client.knowledge_bases.wait_for_database( + "", + )