Skip to content

Commit

Permalink
[PoC] use ijson over json-stream
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jan 31, 2024
1 parent f8d303b commit 7d0d240
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,7 @@ disallow_incomplete_defs = false
module = [
"docx",
"fitz",
"json_stream",
"json_stream.httpx",
"ijson",
"lancedb",
"param",
"pptx",
Expand Down
35 changes: 25 additions & 10 deletions ragna/assistants/_google.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
from typing import Iterator
from typing import AsyncIterator

from ragna._compat import anext
from ragna.core import PackageRequirement, Requirement, Source

from ._api import ApiAssistant


class AsyncIteratorReader:
def __init__(self, ait: AsyncIterator[bytes]) -> None:
self._ait = ait

async def read(self, n: int) -> bytes:
if n == 0:
return b""

try:
return await anext(self._ait)
except StopIteration:
return b""


class GoogleApiAssistant(ApiAssistant):
_API_KEY_ENV_VAR = "GOOGLE_API_KEY"
_MODEL: str
Expand Down Expand Up @@ -36,15 +51,12 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
]
)

def _call_api(
async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> Iterator[str]:
import json_stream.httpx
) -> AsyncIterator[str]:
import ijson

# TODO: Use the async client and make this method async as soon when json-stream
# supports async JSON stream.
# See https://github.com/daggaz/json-stream/issues/54
with self._sync_client.stream(
async with self._async_client.stream(
"POST",
f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
params={"key": self._api_key},
Expand All @@ -70,8 +82,11 @@ def _call_api(
},
},
) as response:
for chunk in json_stream.httpx.load(response, persistent=True):
yield chunk["candidates"][0]["content"]["parts"][0]["text"]
async for chunk in ijson.items(
AsyncIteratorReader(response.aiter_bytes(10 * 1024)),
"item.candidates.item.content.parts.item.text",
):
yield chunk


class GeminiPro(GoogleApiAssistant):
Expand Down

0 comments on commit 7d0d240

Please sign in to comment.