add Google assistants #301

Merged Jan 31, 2024 (13 commits). Viewing changes from 8 commits.
8 changes: 8 additions & 0 deletions docs/references/faq.md
@@ -39,3 +39,11 @@
```bash
export MOSAICML_API_KEY="XXXXX"
```

### [Google](https://ai.google.dev/)

1. ADDME
Member Author commented:
@nenb Could you fill this out similar to what we did for the others?

pmeier marked this conversation as resolved.
2. Set the `GOOGLE_API_KEY` environment variable with your Google API key:
```bash
export GOOGLE_API_KEY="XXXXX"
```
8 changes: 5 additions & 3 deletions pyproject.toml
@@ -25,7 +25,7 @@ dependencies = [
"emoji",
"fastapi",
"httpx",
"httpx_sse",
"httpx-sse",
"importlib_metadata>=4.6; python_version<'3.10'",
"packaging",
"panel>=1.3.6,<1.4",
@@ -56,6 +56,7 @@ Repository = "https://github.com/Quansight/ragna"
# to update the array below, run scripts/update_optional_dependencies.py
all = [
    "chromadb>=0.4.13",
+    "ijson",
    "lancedb>=0.2",
    "pyarrow",
    "pymupdf>=1.23.6",
@@ -139,12 +140,13 @@ disallow_incomplete_defs = false

[[tool.mypy.overrides]]
module = [
+    "docx",
    "fitz",
+    "ijson",
    "lancedb",
    "param",
+    "pyarrow",
-    "docx",
    "pptx",
-    "pyarrow",
    "sentence_transformers",
]
ignore_missing_imports = true
19 changes: 16 additions & 3 deletions ragna/_compat.py
@@ -53,9 +53,22 @@ def _anext() -> Callable[[AsyncIterator[T]], Awaitable[T]]:
    if sys.version_info[:2] >= (3, 10):
        anext = builtins.anext
    else:
-        async def anext(ait: AsyncIterator[T]) -> T:
-            return await ait.__anext__()
+        sentinel = object()
Member Author commented:
This is a slight refactor as I needed the default return value in case of exhaustion.


+        def anext(
+            ait: AsyncIterator[T],
+            default: T = sentinel,  # type: ignore[assignment]
+        ) -> Awaitable[T]:
+            if default is sentinel:
+                return ait.__anext__()
Contributor commented:
I don't fully grok this. Do we not need to await this? And in what situation will this arise?

Member Author commented:
> I don't fully grok this. Do we not need to await this?

We don't, and we actually can't here. Note that `anext` is not an `async def` function; `anext` returns an awaitable, e.g.

```python
async_iterator = ...
awaitable = anext(async_iterator)
result = await awaitable
```

> And in what situation will this arise?

This is the default case, i.e. no default value is set. Let me refactor this function to make it clearer what is going on.


+            async def anext_with_default() -> T:
+                try:
+                    return await ait.__anext__()
+                except StopAsyncIteration:
+                    return default
+
+            return anext_with_default()

    return anext

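As an aside for readers of this thread, here is a minimal sketch of the backported `anext` in action; the usage is hypothetical, but the `(iterator, default)` signature comes from the diff above:

```python
import asyncio

from ragna._compat import anext  # the backport from the diff above


async def main() -> None:
    async def gen():
        yield b"chunk"

    ait = gen()
    print(await anext(ait, b""))  # b"chunk"
    print(await anext(ait, b""))  # b"" (the default): the iterator is exhausted


asyncio.run(main())
```

On Python 3.10 and later the name resolves to `builtins.anext` and behaves identically.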
3 changes: 3 additions & 0 deletions ragna/assistants/__init__.py
@@ -1,6 +1,8 @@
__all__ = [
    "Claude",
    "ClaudeInstant",
+    "GeminiPro",
+    "GeminiUltra",
    "Gpt35Turbo16k",
    "Gpt4",
    "Mpt7bInstruct",
@@ -10,6 +12,7 @@

from ._anthropic import Claude, ClaudeInstant
from ._demo import RagnaDemoAssistant
+from ._google import GeminiPro, GeminiUltra
from ._mosaicml import Mpt7bInstruct, Mpt30bInstruct
from ._openai import Gpt4, Gpt35Turbo16k

3 changes: 1 addition & 2 deletions ragna/assistants/_api.py
@@ -17,8 +17,7 @@ def requirements(cls) -> list[Requirement]:

    def __init__(self) -> None:
        self._client = httpx.AsyncClient(
-            headers={"User-Agent": f"{ragna.__version__}/{self}"},
-            timeout=60,
+            headers={"User-Agent": f"{ragna.__version__}/{self}"}, timeout=60
        )
        self._api_key = os.environ[self._API_KEY_ENV_VAR]

117 changes: 117 additions & 0 deletions ragna/assistants/_google.py
@@ -0,0 +1,117 @@
from typing import AsyncIterator

from ragna._compat import anext
from ragna.core import PackageRequirement, Requirement, Source

from ._api import ApiAssistant


class AsyncIteratorReader:
    def __init__(self, ait: AsyncIterator[bytes]) -> None:
        self._ait = ait

    async def read(self, n: int) -> bytes:
        if n == 0:
Contributor commented:
Nitpick: could you add some documentation here on why the `n` arg is required/used? I get that `ijson` expects a file-like object, but does that also imply the `n` arg, or is it something specific to `ijson`?

return b""
return await anext(self._ait, b"") # type: ignore[call-arg]
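To the reviewer's question about `n`: `ijson` treats the object it is handed as a file and repeatedly calls `read(n)` until an empty result signals end of stream, so the argument belongs to the file-like protocol rather than anything `ijson`-specific. A rough sketch of that consumption loop, for illustration only (not `ijson`'s actual internals):

```python
# Hypothetical sketch of how a streaming parser drives a file-like reader.
async def drain(reader) -> bytes:
    data = b""
    while True:
        chunk = await reader.read(65536)  # the parser chooses the chunk size n
        if not chunk:  # b"" signals end of stream
            break
        data += chunk  # a real parser would tokenize incrementally instead
    return data
```

Since `AsyncIteratorReader.read` ignores `n` beyond the `n == 0` check and simply returns the next chunk from the wrapped iterator, it may return fewer than `n` bytes per call, which file-like objects are allowed to do.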


class GoogleApiAssistant(ApiAssistant):
    _API_KEY_ENV_VAR = "GOOGLE_API_KEY"
    _MODEL: str
    _CONTEXT_SIZE: int

    @classmethod
    def requirements(cls) -> list[Requirement]:
        return [
            *super().requirements(),
            PackageRequirement("ijson"),
        ]

    @classmethod
    def display_name(cls) -> str:
        return f"Google/{cls._MODEL}"

    @property
    def max_input_size(self) -> int:
        return self._CONTEXT_SIZE

    def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
        # https://ai.google.dev/docs/prompt_best_practices#add-contextual-information
        return "\n".join(
            [
                "Answer the prompt using only the pieces of context below.",
                "If you don't know the answer, just say so. Don't try to make up additional context.",
                f"Prompt: {prompt}",
                *[f"\n{source.content}" for source in sources],
            ]
        )
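As a worked example of the format above, a hypothetical prompt "What is Ragna?" with a single source reading "Ragna is a RAG orchestration framework." would be assembled into:

```python
# Hypothetical output of _instructize_prompt (prompt and source content invented):
instruction = (
    "Answer the prompt using only the pieces of context below.\n"
    "If you don't know the answer, just say so. Don't try to make up additional context.\n"
    "Prompt: What is Ragna?\n"
    "\nRagna is a RAG orchestration framework."
)
```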

    async def _call_api(
        self, prompt: str, sources: list[Source], *, max_new_tokens: int
    ) -> AsyncIterator[str]:
        import ijson

        async with self._client.stream(
            "POST",
            f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
            params={"key": self._api_key},
            headers={"Content-Type": "application/json"},
            json={
                "contents": [
                    {"parts": [{"text": self._instructize_prompt(prompt, sources)}]}
                ],
                # https://ai.google.dev/docs/safety_setting_gemini
                "safetySettings": [
                    {"category": f"HARM_CATEGORY_{category}", "threshold": "BLOCK_NONE"}
                    for category in [
                        "HARASSMENT",
                        "HATE_SPEECH",
                        "SEXUALLY_EXPLICIT",
                        "DANGEROUS_CONTENT",
                    ]
                ],
                # https://ai.google.dev/tutorials/rest_quickstart#configuration
                "generationConfig": {
                    "temperature": 0.0,
Contributor commented:
If we are going to hard-code this, then I would suggest a higher value, as this is what most users will require.

Member Author commented:
It is the other way around: we want to hardcode 0.0 here, because that means determinism. If we learned one thing from trying to bring RAG to businesses, it is that they want to get exactly the same answer if they ask the same question twice. Of course we can't guarantee that, since we don't control the model, but we can do our best to at least avoid sampling during generation.

This is the same for all other assistants that we currently have:

```python
"temperature": 0.0,

"parameters": {"temperature": 0.0, "max_new_tokens": max_new_tokens},

"temperature": 0.0,
```

"maxOutputTokens": max_new_tokens,
},
},
) as response:
async for chunk in ijson.items(
AsyncIteratorReader(response.aiter_bytes(1024)),
"item.candidates.item.content.parts.item.text",
):
yield chunk
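For context on the `ijson` prefix: `streamGenerateContent` returns a JSON array of response objects, and `item.candidates.item.content.parts.item.text` walks each array item, each of its candidates, and each text part. A sketch of the assumed shape, with invented values:

```python
# Assumed shape of the streamed response array (all values made up):
response_body = [
    {"candidates": [{"content": {"parts": [{"text": "first chunk of the answer"}]}}]},
    {"candidates": [{"content": {"parts": [{"text": " ...and its continuation"}]}}]},
]
# For the prefix above, ijson would yield:
#   "first chunk of the answer", " ...and its continuation"
```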


class GeminiPro(GoogleApiAssistant):
    """[Google Gemini Pro](https://ai.google.dev/models/gemini)

    !!! info "Required environment variables"

        - `GOOGLE_API_KEY`

    !!! info "Required packages"

        - `ijson`
    """

    _MODEL = "gemini-pro"
    _CONTEXT_SIZE = 30_720


class GeminiUltra(GoogleApiAssistant):
    """[Google Gemini Ultra](https://ai.google.dev/models/gemini)

    !!! info "Required environment variables"

        - `GOOGLE_API_KEY`

    !!! info "Required packages"

        - `ijson`
    """

    _MODEL = "gemini-ultra"
    _CONTEXT_SIZE = 30_720
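A tiny smoke test of the new classes is possible without an API key, since `display_name` and `requirements` are classmethods; the expected output follows from the class attributes in the diff:

```python
from ragna.assistants import GeminiPro, GeminiUltra

print(GeminiPro.display_name())    # Google/gemini-pro
print(GeminiUltra.display_name())  # Google/gemini-ultra

# requirements() includes the ijson package requirement added above
print(GeminiPro.requirements())
```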
2 changes: 1 addition & 1 deletion ragna/deploy/_ui/api_wrapper.py
@@ -18,7 +18,7 @@ class ApiWrapper(param.Parameterized):
    auth_token = param.String(default=None)

    def __init__(self, api_url, **params):
-        self.client = httpx.AsyncClient(base_url=api_url)
+        self.client = httpx.AsyncClient(base_url=api_url, timeout=60)
Member Author commented:
Even when streaming, the Google assistants return really large chunks and thus easily go over the default timeout. The new timeout is in line with what we use for our builtin assistants as well:

```python
self._client = httpx.AsyncClient(
    headers={"User-Agent": f"{ragna.__version__}/{self}"},
    timeout=60,
)
```


        super().__init__(**params)

2 changes: 2 additions & 0 deletions requirements-docker.lock
@@ -109,6 +109,8 @@ idna==3.6
    # anyio
    # httpx
    # requests
+ijson==3.2.3
+    # via Ragna (pyproject.toml)
importlib-metadata==6.11.0
    # via opentelemetry-api
importlib-resources==6.1.1