-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add Google assistants #301
Changes from 8 commits
dad1024
74fd1e0
859d85a
f8d303b
7d0d240
61f60ed
2bc4f35
a1d34a5
9d6de3c
6f1f9e5
40d8a59
423c913
1e62e04
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,9 +53,22 @@ def _anext() -> Callable[[AsyncIterator[T]], Awaitable[T]]: | |
# Python 3.10 added ``builtins.anext`` (including its optional *default*
# argument); provide a behaviorally identical backport for older interpreters.
if sys.version_info[:2] >= (3, 10):
    anext = builtins.anext
else:
    # Private sentinel so that any value — including None — is usable as a
    # caller-supplied default.
    sentinel = object()

    def anext(
        ait: AsyncIterator[T],
        default: T = sentinel,  # type: ignore[assignment]
    ) -> Awaitable[T]:
        """Backport of :func:`builtins.anext`.

        Note that, like the builtin, this function is *not* a coroutine: it
        returns an awaitable which the caller awaits, i.e.::

            awaitable = anext(async_iterator)
            result = await awaitable

        so no ``await`` is needed (or possible) inside the default case below.
        """
        if default is sentinel:
            # No default given: hand back the iterator's own __anext__
            # coroutine, letting StopAsyncIteration propagate to the caller.
            return ait.__anext__()

        async def anext_with_default() -> T:
            # Default given: swallow exhaustion and return the default,
            # matching the builtin's semantics.
            try:
                return await ait.__anext__()
            except StopAsyncIteration:
                return default

        return anext_with_default()
|
||
return anext | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,117 @@ | ||||||||
from typing import AsyncIterator | ||||||||
|
||||||||
from ragna._compat import anext | ||||||||
from ragna.core import PackageRequirement, Requirement, Source | ||||||||
|
||||||||
from ._api import ApiAssistant | ||||||||
|
||||||||
|
||||||||
class AsyncIteratorReader: | ||||||||
def __init__(self, ait: AsyncIterator[bytes]) -> None: | ||||||||
self._ait = ait | ||||||||
|
||||||||
async def read(self, n: int) -> bytes: | ||||||||
if n == 0: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick: could you add some documentation here on why the |
||||||||
return b"" | ||||||||
return await anext(self._ait, b"") # type: ignore[call-arg] | ||||||||
|
||||||||
|
||||||||
class GoogleApiAssistant(ApiAssistant):
    """Base class for assistants backed by the Google Generative Language API.

    Subclasses set ``_MODEL`` (the API model identifier) and
    ``_CONTEXT_SIZE`` (the context window reported via ``max_input_size``).
    """

    _API_KEY_ENV_VAR = "GOOGLE_API_KEY"
    _MODEL: str
    _CONTEXT_SIZE: int

    @classmethod
    def requirements(cls) -> list[Requirement]:
        # ijson is needed to incrementally parse the streamed JSON response.
        return [
            *super().requirements(),
            PackageRequirement("ijson"),
        ]

    @classmethod
    def display_name(cls) -> str:
        return f"Google/{cls._MODEL}"

    @property
    def max_input_size(self) -> int:
        return self._CONTEXT_SIZE

    def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
        # Embed the retrieved sources as contextual information, per
        # https://ai.google.dev/docs/prompt_best_practices#add-contextual-information
        return "\n".join(
            [
                "Answer the prompt using only the pieces of context below.",
                "If you don't know the answer, just say so. Don't try to make up additional context.",
                f"Prompt: {prompt}",
                *[f"\n{source.content}" for source in sources],
            ]
        )

    async def _call_api(
        self, prompt: str, sources: list[Source], *, max_new_tokens: int
    ) -> AsyncIterator[str]:
        import ijson

        async with self._client.stream(
            "POST",
            f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
            params={"key": self._api_key},
            headers={"Content-Type": "application/json"},
            json={
                "contents": [
                    {"parts": [{"text": self._instructize_prompt(prompt, sources)}]}
                ],
                # https://ai.google.dev/docs/safety_setting_gemini
                "safetySettings": [
                    {"category": f"HARM_CATEGORY_{category}", "threshold": "BLOCK_NONE"}
                    for category in [
                        "HARASSMENT",
                        "HATE_SPEECH",
                        "SEXUALLY_EXPLICIT",
                        "DANGEROUS_CONTENT",
                    ]
                ],
                # https://ai.google.dev/tutorials/rest_quickstart#configuration
                "generationConfig": {
                    # Hardcoded to 0.0 so answers stay grounded in the given
                    # sources rather than being "creative" — consistent with
                    # the other builtin assistants (Anthropic, MosaicML,
                    # OpenAI).
                    "temperature": 0.0,
                    "maxOutputTokens": max_new_tokens,
                },
            },
        ) as response:
            # The endpoint streams a JSON array of candidates; parse it
            # incrementally and yield each text part as soon as it arrives.
            async for chunk in ijson.items(
                AsyncIteratorReader(response.aiter_bytes(1024)),
                "item.candidates.item.content.parts.item.text",
            ):
                yield chunk
|
||||||||
|
||||||||
class GeminiPro(GoogleApiAssistant):
    """[Google Gemini Pro](https://ai.google.dev/models/gemini)

    !!! info "Required environment variables"

        - `GOOGLE_API_KEY`

    !!! info "Required packages"

        - `ijson`
    """

    _MODEL = "gemini-pro"
    _CONTEXT_SIZE = 30_720
|
||||||||
|
||||||||
class GeminiUltra(GoogleApiAssistant):
    """[Google Gemini Ultra](https://ai.google.dev/models/gemini)

    !!! info "Required environment variables"

        - `GOOGLE_API_KEY`

    !!! info "Required packages"

        - `ijson`
    """

    _MODEL = "gemini-ultra"
    _CONTEXT_SIZE = 30_720
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -18,7 +18,7 @@ class ApiWrapper(param.Parameterized): | |||||||||
auth_token = param.String(default=None) | ||||||||||
|
||||||||||
def __init__(self, api_url, **params):
    # 60s timeout: even when streaming, some assistants (e.g. the Google
    # ones) return very large chunks that easily exceed httpx's 5s default.
    # This matches the timeout used by the builtin assistants' API client.
    self.client = httpx.AsyncClient(base_url=api_url, timeout=60)
    super().__init__(**params)
|
||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nenb Could you fill out this similar to what we did for the others?