Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,49 @@ we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/)
to add `DIGITALOCEAN_ACCESS_TOKEN="My Access Token"`, `GRADIENT_MODEL_ACCESS_KEY="My Model Access Key"` to your `.env` file
so that your keys are not stored in source control.

## Knowledge Base Database Polling

When creating a Knowledge Base, the database deployment can take several minutes. The `wait_for_database()` helper function simplifies polling for the database status:

```python
from gradient import Gradient
from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError
from gradient._exceptions import APITimeoutError

client = Gradient()

# Create a knowledge base
kb_response = client.knowledge_bases.create(
name="My Knowledge Base",
region="nyc1",
embedding_model_uuid="your-embedding-model-uuid",
)

kb_uuid = kb_response.knowledge_base.uuid

try:
# Wait for the database to be ready (default: 10 minute timeout, 5 second poll interval)
result = client.knowledge_bases.wait_for_database(kb_uuid)
print(f"Database status: {result.database_status}") # "ONLINE"

# Custom timeout and poll interval
result = client.knowledge_bases.wait_for_database(
kb_uuid,
timeout=900.0, # 15 minutes
poll_interval=10.0 # Check every 10 seconds
)

except KnowledgeBaseDatabaseError as e:
# Database entered a failed state (DECOMMISSIONED or UNHEALTHY)
print(f"Database failed: {e}")

except APITimeoutError:
# Database did not become ready within the timeout period
print("Timeout: Database did not become ready in time")
```

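The helper polls until the database status is `ONLINE`. It raises `KnowledgeBaseDatabaseError` if the database enters a failed state (`DECOMMISSIONED` or `UNHEALTHY`) and `APITimeoutError` if the timeout elapses first.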

## Async usage

Simply import `AsyncGradient` instead of `Gradient` and use `await` with each API call:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -246,5 +246,5 @@ known-first-party = ["gradient", "tests"]
[tool.ruff.lint.per-file-ignores]
"bin/**.py" = ["T201", "T203"]
"scripts/**.py" = ["T201", "T203"]
"tests/**.py" = ["T201", "T203"]
"tests/**.py" = ["T201", "T203", "ARG001"]
"examples/**.py" = ["T201", "T203"]
2 changes: 2 additions & 0 deletions src/gradient/resources/knowledge_bases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .knowledge_bases import (
KnowledgeBasesResource,
AsyncKnowledgeBasesResource,
KnowledgeBaseDatabaseError,
KnowledgeBasesResourceWithRawResponse,
AsyncKnowledgeBasesResourceWithRawResponse,
KnowledgeBasesResourceWithStreamingResponse,
Expand All @@ -40,6 +41,7 @@
"AsyncIndexingJobsResourceWithStreamingResponse",
"KnowledgeBasesResource",
"AsyncKnowledgeBasesResource",
"KnowledgeBaseDatabaseError",
"KnowledgeBasesResourceWithRawResponse",
"AsyncKnowledgeBasesResourceWithRawResponse",
"KnowledgeBasesResourceWithStreamingResponse",
Expand Down
181 changes: 180 additions & 1 deletion src/gradient/resources/knowledge_bases/knowledge_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import time
import asyncio
from typing import Iterable

import httpx
Expand All @@ -25,6 +27,7 @@
DataSourcesResourceWithStreamingResponse,
AsyncDataSourcesResourceWithStreamingResponse,
)
from ..._exceptions import APITimeoutError
from .indexing_jobs import (
IndexingJobsResource,
AsyncIndexingJobsResource,
Expand All @@ -40,7 +43,13 @@
from ...types.knowledge_base_update_response import KnowledgeBaseUpdateResponse
from ...types.knowledge_base_retrieve_response import KnowledgeBaseRetrieveResponse

__all__ = ["KnowledgeBasesResource", "AsyncKnowledgeBasesResource"]
__all__ = ["KnowledgeBasesResource", "AsyncKnowledgeBasesResource", "KnowledgeBaseDatabaseError"]


class KnowledgeBaseDatabaseError(Exception):
    """Signal that a knowledge base's database reached a failed state.

    Raised by the ``wait_for_database`` polling helpers when the reported
    ``database_status`` is one of the failed states (``DECOMMISSIONED`` or
    ``UNHEALTHY``) rather than ``ONLINE``.
    """


class KnowledgeBasesResource(SyncAPIResource):
Expand Down Expand Up @@ -330,6 +339,85 @@ def delete(
cast_to=KnowledgeBaseDeleteResponse,
)

def wait_for_database(
    self,
    uuid: str,
    *,
    timeout: float = 600.0,
    poll_interval: float = 5.0,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
) -> KnowledgeBaseRetrieveResponse:
    """
    Poll the knowledge base until the database status is ONLINE or a failed state is reached.

    This helper function repeatedly calls retrieve() to check the database_status field.
    It will wait for the database to become ONLINE, or raise an exception if it enters
    a failed state (DECOMMISSIONED or UNHEALTHY) or if the timeout is exceeded.

    The deadline is checked before each poll; a retrieve() call that is already in
    flight is not interrupted when the deadline passes.

    Args:
      uuid: The knowledge base UUID to poll

      timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes)

      poll_interval: Time to wait between polls in seconds (default: 5 seconds)

      extra_headers: Send extra headers

      extra_query: Add additional query parameters to the request

      extra_body: Add additional JSON properties to the request

    Returns:
      The final KnowledgeBaseRetrieveResponse when the database status is ONLINE

    Raises:
      ValueError: If `uuid` is empty

      KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY)

      APITimeoutError: If the timeout is exceeded before the database becomes ONLINE
    """
    if not uuid:
        raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}")

    failed_states = {"DECOMMISSIONED", "UNHEALTHY"}

    # Use the monotonic clock: wall-clock time (time.time) can jump forwards or
    # backwards on NTP/manual adjustments, which would shorten or extend the wait.
    deadline = time.monotonic() + timeout

    while True:
        if time.monotonic() >= deadline:
            raise APITimeoutError(
                request=httpx.Request(
                    method="GET",
                    url=f"https://api.digitalocean.com/v2/gen-ai/knowledge_bases/{uuid}",
                )
            )

        response = self.retrieve(
            uuid,
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
        )

        status = response.database_status

        if status == "ONLINE":
            return response

        if status in failed_states:
            raise KnowledgeBaseDatabaseError(
                f"Knowledge base database entered failed state: {status}"
            )

        # Sleep before the next poll without overshooting the deadline. The remaining
        # time is measured *after* the request so the round-trip duration is accounted for.
        remaining_time = deadline - time.monotonic()
        sleep_time = min(poll_interval, remaining_time)
        if sleep_time > 0:
            time.sleep(sleep_time)


class AsyncKnowledgeBasesResource(AsyncAPIResource):
@cached_property
Expand Down Expand Up @@ -618,6 +706,85 @@ async def delete(
cast_to=KnowledgeBaseDeleteResponse,
)

async def wait_for_database(
    self,
    uuid: str,
    *,
    timeout: float = 600.0,
    poll_interval: float = 5.0,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
) -> KnowledgeBaseRetrieveResponse:
    """
    Poll the knowledge base until the database status is ONLINE or a failed state is reached.

    This helper function repeatedly awaits retrieve() to check the database_status field.
    It will wait for the database to become ONLINE, or raise an exception if it enters
    a failed state (DECOMMISSIONED or UNHEALTHY) or if the timeout is exceeded.

    The deadline is checked before each poll; a retrieve() call that is already in
    flight is not interrupted when the deadline passes.

    Args:
      uuid: The knowledge base UUID to poll

      timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes)

      poll_interval: Time to wait between polls in seconds (default: 5 seconds)

      extra_headers: Send extra headers

      extra_query: Add additional query parameters to the request

      extra_body: Add additional JSON properties to the request

    Returns:
      The final KnowledgeBaseRetrieveResponse when the database status is ONLINE

    Raises:
      ValueError: If `uuid` is empty

      KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY)

      APITimeoutError: If the timeout is exceeded before the database becomes ONLINE
    """
    if not uuid:
        raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}")

    failed_states = {"DECOMMISSIONED", "UNHEALTHY"}

    # Use the monotonic clock: wall-clock time (time.time) can jump forwards or
    # backwards on NTP/manual adjustments, which would shorten or extend the wait.
    deadline = time.monotonic() + timeout

    while True:
        if time.monotonic() >= deadline:
            raise APITimeoutError(
                request=httpx.Request(
                    method="GET",
                    url=f"https://api.digitalocean.com/v2/gen-ai/knowledge_bases/{uuid}",
                )
            )

        response = await self.retrieve(
            uuid,
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
        )

        status = response.database_status

        if status == "ONLINE":
            return response

        if status in failed_states:
            raise KnowledgeBaseDatabaseError(
                f"Knowledge base database entered failed state: {status}"
            )

        # Sleep before the next poll without overshooting the deadline. The remaining
        # time is measured *after* the request so the round-trip duration is accounted for.
        remaining_time = deadline - time.monotonic()
        sleep_time = min(poll_interval, remaining_time)
        if sleep_time > 0:
            await asyncio.sleep(sleep_time)


class KnowledgeBasesResourceWithRawResponse:
def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None:
Expand All @@ -638,6 +805,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None:
self.delete = to_raw_response_wrapper(
knowledge_bases.delete,
)
self.wait_for_database = to_raw_response_wrapper(
knowledge_bases.wait_for_database,
)

@cached_property
def data_sources(self) -> DataSourcesResourceWithRawResponse:
Expand Down Expand Up @@ -667,6 +837,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None:
self.delete = async_to_raw_response_wrapper(
knowledge_bases.delete,
)
self.wait_for_database = async_to_raw_response_wrapper(
knowledge_bases.wait_for_database,
)

@cached_property
def data_sources(self) -> AsyncDataSourcesResourceWithRawResponse:
Expand Down Expand Up @@ -696,6 +869,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None:
self.delete = to_streamed_response_wrapper(
knowledge_bases.delete,
)
self.wait_for_database = to_streamed_response_wrapper(
knowledge_bases.wait_for_database,
)

@cached_property
def data_sources(self) -> DataSourcesResourceWithStreamingResponse:
Expand Down Expand Up @@ -725,6 +901,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None:
self.delete = async_to_streamed_response_wrapper(
knowledge_bases.delete,
)
self.wait_for_database = async_to_streamed_response_wrapper(
knowledge_bases.wait_for_database,
)

@cached_property
def data_sources(self) -> AsyncDataSourcesResourceWithStreamingResponse:
Expand Down
Loading