diff --git a/pyproject.toml b/pyproject.toml index d246fe6c..cc39418d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,8 +142,7 @@ disallow_incomplete_defs = false module = [ "docx", "fitz", - "json_stream", - "json_stream.httpx", + "ijson", "lancedb", "param", "pptx", diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index 24bf9071..1b952a25 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -1,10 +1,25 @@ -from typing import Iterator +from typing import AsyncIterator +from ragna._compat import anext from ragna.core import PackageRequirement, Requirement, Source from ._api import ApiAssistant +class AsyncIteratorReader: + def __init__(self, ait: AsyncIterator[bytes]) -> None: + self._ait = ait + + async def read(self, n: int) -> bytes: + if n == 0: + return b"" + + try: + return await anext(self._ait) + except StopIteration: + return b"" + + class GoogleApiAssistant(ApiAssistant): _API_KEY_ENV_VAR = "GOOGLE_API_KEY" _MODEL: str @@ -36,15 +51,12 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str: ] ) - def _call_api( + async def _call_api( self, prompt: str, sources: list[Source], *, max_new_tokens: int - ) -> Iterator[str]: - import json_stream.httpx + ) -> AsyncIterator[str]: + import ijson - # TODO: Use the async client and make this method async as soon when json-stream - # supports async JSON stream. - # See https://github.com/daggaz/json-stream/issues/54 - with self._sync_client.stream( + async with self._async_client.stream( "POST", f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent", params={"key": self._api_key}, @@ -70,8 +82,11 @@ def _call_api( }, }, ) as response: - for chunk in json_stream.httpx.load(response, persistent=True): - yield chunk["candidates"][0]["content"]["parts"][0]["text"] + async for chunk in ijson.items( + AsyncIteratorReader(response.aiter_bytes(10 * 1024)), + "item.candidates.item.content.parts.item.text", + ): + yield chunk class GeminiPro(GoogleApiAssistant):