add Google assistants #301

Merged · 13 commits · Jan 31, 2024

Changes from 1 commit
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -139,12 +139,14 @@ disallow_incomplete_defs = false

[[tool.mypy.overrides]]
module = [
"docx",
"fitz",
"json_stream",
"json_stream.httpx",
"lancedb",
"param",
"pyarrow",
"docx",
"pptx",
"pyarrow",
"sentence_transformers",
]
ignore_missing_imports = true
42 changes: 41 additions & 1 deletion ragna/_utils.py
@@ -4,9 +4,23 @@
import sys
import threading
from pathlib import Path
from typing import Any, Callable, Optional, Union
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Iterator,
Optional,
TypeVar,
Union,
cast,
)
from urllib.parse import SplitResult, urlsplit, urlunsplit

from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

T = TypeVar("T")

_LOCAL_ROOT = (
Path(os.environ.get("RAGNA_LOCAL_ROOT", "~/.cache/ragna")).expanduser().resolve()
)
@@ -125,3 +139,29 @@ def is_debugging() -> bool:
if any(part.startswith(name) for part in parts):
return True
return False


def as_awaitable(
fn: Union[Callable[..., T], Callable[..., Awaitable[T]]],
*args: Any,
**kwargs: Any,
) -> Awaitable[T]:
if inspect.iscoroutinefunction(fn):
fn = cast(Callable[..., Awaitable[T]], fn)
return fn(*args, **kwargs)
else:
fn = cast(Callable[..., T], fn)
return run_in_threadpool(fn, *args, **kwargs)


def as_async_iterator(
fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]],
*args: Any,
**kwargs: Any,
) -> AsyncIterator[T]:
if inspect.isasyncgenfunction(fn):
fn = cast(Callable[..., AsyncIterator[T]], fn)
return fn(*args, **kwargs)
else:
fn = cast(Callable[..., Iterator[T]], fn)
return iterate_in_threadpool(fn(*args, **kwargs))
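
The two new helpers normalize sync and async callables behind a single awaitable / async-iterator interface: coroutine functions and async generators are used directly, while plain functions and generators are pushed to a thread pool. A minimal usage sketch, purely illustrative (the example functions are made up and not part of this PR):

    import asyncio

    from ragna._utils import as_async_iterator, as_awaitable


    def add(a: int, b: int) -> int:  # plain sync function
        return a + b


    async def amul(a: int, b: int) -> int:  # coroutine function
        return a * b


    def count(n: int):  # plain generator
        yield from range(n)


    async def main() -> None:
        # Sync callables go through run_in_threadpool, async ones are awaited
        # directly, so callers never need to care which kind they were handed.
        print(await as_awaitable(add, 1, 2))   # -> 3
        print(await as_awaitable(amul, 2, 3))  # -> 6

        # A plain generator is wrapped with iterate_in_threadpool and consumed
        # like any other async iterator.
        async for item in as_async_iterator(count, 3):
            print(item)  # -> 0, 1, 2


    asyncio.run(main())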
3 changes: 3 additions & 0 deletions ragna/assistants/__init__.py
@@ -1,6 +1,8 @@
__all__ = [
"Claude",
"ClaudeInstant",
"GeminiPro",
"GeminiUltra",
"Gpt35Turbo16k",
"Gpt4",
"Mpt7bInstruct",
@@ -10,6 +12,7 @@

from ._anthropic import Claude, ClaudeInstant
from ._demo import RagnaDemoAssistant
from ._google import GeminiPro, GeminiUltra
from ._mosaicml import Mpt7bInstruct, Mpt30bInstruct
from ._openai import Gpt4, Gpt35Turbo16k

2 changes: 1 addition & 1 deletion ragna/assistants/_anthropic.py
@@ -36,7 +36,7 @@ async def _call_api(
) -> AsyncIterator[str]:
# See https://docs.anthropic.com/claude/reference/streaming
async with httpx_sse.aconnect_sse(
self._client,
self._async_client,
"POST",
"https://api.anthropic.com/v1/complete",
headers={
21 changes: 14 additions & 7 deletions ragna/assistants/_api.py
@@ -1,10 +1,11 @@
import abc
import os
from typing import AsyncIterator
from typing import Any, AsyncIterator, Iterator

import httpx

import ragna
from ragna._utils import as_async_iterator
from ragna.core import Assistant, EnvVarRequirement, Requirement, Source


@@ -16,22 +17,28 @@ def requirements(cls) -> list[Requirement]:
return [EnvVarRequirement(cls._API_KEY_ENV_VAR)]

def __init__(self) -> None:
self._client = httpx.AsyncClient(
self._api_key = os.environ[self._API_KEY_ENV_VAR]

kwargs: dict[str, Any] = dict(
headers={"User-Agent": f"{ragna.__version__}/{self}"},
timeout=60,
)
self._api_key = os.environ[self._API_KEY_ENV_VAR]
self._sync_client = httpx.Client(**kwargs)
self._async_client = httpx.AsyncClient(**kwargs)

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> AsyncIterator[str]:
async for chunk in self._call_api( # type: ignore[attr-defined, misc]
prompt, sources, max_new_tokens=max_new_tokens
async for chunk in as_async_iterator(
self._call_api,
prompt,
sources,
max_new_tokens=max_new_tokens,
):
yield chunk

@abc.abstractmethod
async def _call_api(
def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> AsyncIterator[str]:
) -> Iterator[str]:
...
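
With this change a concrete assistant can implement _call_api either as an async generator or as a plain generator, since answer() now routes it through as_async_iterator. A rough sketch of a synchronous implementation under that contract — the class, endpoint, payload, and response handling below are invented for illustration and are not part of the PR:

    from typing import Iterator

    from ragna.core import Source

    from ._api import ApiAssistant


    class ExampleApiAssistant(ApiAssistant):
        # Hypothetical assistant; this env var and endpoint do not exist.
        _API_KEY_ENV_VAR = "EXAMPLE_API_KEY"

        @property
        def max_input_size(self) -> int:
            return 4_096

        def _call_api(
            self, prompt: str, sources: list[Source], *, max_new_tokens: int
        ) -> Iterator[str]:
            # A plain generator is fine here: answer() iterates it in a thread
            # pool via as_async_iterator, so the blocking HTTP call does not
            # stall the event loop.
            with self._sync_client.stream(
                "POST",
                "https://api.example.com/v1/generate",
                headers={"Authorization": f"Bearer {self._api_key}"},
                json={"prompt": prompt, "max_tokens": max_new_tokens},
            ) as response:
                for line in response.iter_lines():
                    if line:
                        yield line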
95 changes: 95 additions & 0 deletions ragna/assistants/_google.py
@@ -0,0 +1,95 @@
from typing import Iterator

from ragna.core import PackageRequirement, Requirement, Source

from ._api import ApiAssistant


class GoogleApiAssistant(ApiAssistant):
_API_KEY_ENV_VAR = "GOOGLE_API_KEY"
_MODEL: str
_CONTEXT_SIZE: int

@classmethod
def requirements(cls) -> list[Requirement]:
return [
*super().requirements(),
PackageRequirement("json-stream"),
]

@classmethod
def display_name(cls) -> str:
return f"Google/{cls._MODEL}"

@property
def max_input_size(self) -> int:
return self._CONTEXT_SIZE

def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
# https://ai.google.dev/docs/prompt_best_practices#add-contextual-information
return "\n".join(
[
"Answer the prompt using only the pieces of context below.",
"If you don't know the answer, just say so. Don't try to make up additional context.",
f"Prompt: {prompt}",
*[f"\n{source.content}" for source in sources],
]
)

def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> Iterator[str]:
import json_stream.httpx

with self._sync_client.stream(
"POST",
f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
params={"key": self._api_key},
headers={"Content-Type": "application/json"},
json={
"contents": [
{"parts": [{"text": self._instructize_prompt(prompt, sources)}]}
],
# https://ai.google.dev/docs/safety_setting_gemini
"safetySettings": [
{"category": f"HARM_CATEGORY_{category}", "threshold": "BLOCK_NONE"}
for category in [
"HARASSMENT",
"HATE_SPEECH",
"SEXUALLY_EXPLICIT",
"DANGEROUS_CONTENT",
]
],
# https://ai.google.dev/tutorials/rest_quickstart#configuration
"generationConfig": {
"temperature": 0.0,
Contributor commented on "temperature": 0.0:

    If we are going to hard-code this, then I would suggest a higher value, as this is what most users will require.

Member Author (pmeier) replied:

    It is the other way around: we want to hardcode 0.0 here, because that means determinism. If we learned one thing from trying to bring RAG to businesses, it is that they want to get exactly the same answer if they ask the same question twice. Of course we can't guarantee it, since we don't control the model, but we can do our best to at least avoid sampling during generation.

    This is the same for all other assistants that we currently have:

        "temperature": 0.0,
        "parameters": {"temperature": 0.0, "max_new_tokens": max_new_tokens},
        "temperature": 0.0,

"maxOutputTokens": max_new_tokens,
},
},
) as response:
for chunk in json_stream.httpx.load(response, persistent=True):
yield chunk["candidates"][0]["content"]["parts"][0]["text"]


class GeminiPro(GoogleApiAssistant):
"""[Google Gemini Pro](https://ai.google.dev/models/gemini)

!!! info "Required environment variables"

- `GOOGLE_API_KEY`
"""

_MODEL = "gemini-pro"
_CONTEXT_SIZE = 30_720


class GeminiUltra(GoogleApiAssistant):
"""[Google Gemini Ultra](https://ai.google.dev/models/gemini)

!!! info "Required environment variables"

- `GOOGLE_API_KEY`
"""

_MODEL = "gemini-ultra"
_CONTEXT_SIZE = 30_720
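
For context on the parsing in _call_api above: streamGenerateContent streams a JSON array of response objects, and each element carries the next text fragment at candidates[0].content.parts[0].text. An abridged sketch of one element and the extraction, with the field layout loosely following the REST quickstart (illustrative, not exhaustive):

    # One streamed element, heavily abridged; real responses carry more fields
    # (finishReason, safetyRatings, ...).
    chunk = {
        "candidates": [
            {
                "content": {
                    "parts": [{"text": "The answer, continued..."}],
                    "role": "model",
                },
            }
        ],
    }

    # The same extraction used in GoogleApiAssistant._call_api:
    text = chunk["candidates"][0]["content"]["parts"][0]["text"]
    print(text)  # -> "The answer, continued..."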
2 changes: 1 addition & 1 deletion ragna/assistants/_mosaicml.py
@@ -32,7 +32,7 @@ async def _call_api(
) -> AsyncIterator[str]:
instruction = self._instructize_prompt(prompt, sources)
# https://docs.mosaicml.com/en/latest/inference.html#text-completion-requests
response = await self._client.post(
response = await self._async_client.post(
f"https://models.hosted-on.mosaicml.hosting/{self._MODEL}/v1/predict",
headers={
"Authorization": f"{self._api_key}",
2 changes: 1 addition & 1 deletion ragna/assistants/_openai.py
@@ -36,7 +36,7 @@ async def _call_api(
# See https://platform.openai.com/docs/api-reference/chat/create
# and https://platform.openai.com/docs/api-reference/chat/streaming
async with httpx_sse.aconnect_sse(
self._client,
self._async_client,
"POST",
"https://api.openai.com/v1/chat/completions",
headers={
29 changes: 7 additions & 22 deletions ragna/core/_rag.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import datetime
import inspect
import uuid
from typing import (
Any,
@@ -19,7 +18,8 @@
)

import pydantic
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

from ragna._utils import as_async_iterator, as_awaitable

from ._components import Assistant, Component, Message, MessageRole, SourceStorage
from ._document import Document, LocalDocument
@@ -256,34 +256,19 @@ def _unpack_chat_params(
for fn, model in component_models.items()
}

async def _run(
def _run(
self, fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any
) -> T:
) -> Awaitable[T]:
kwargs = self._unpacked_params[fn]
if inspect.iscoroutinefunction(fn):
fn = cast(Callable[..., Awaitable[T]], fn)
coro = fn(*args, **kwargs)
else:
fn = cast(Callable[..., T], fn)
coro = run_in_threadpool(fn, *args, **kwargs)

return await coro
return as_awaitable(fn, *args, **kwargs)

async def _run_gen(
def _run_gen(
self,
fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]],
*args: Any,
) -> AsyncIterator[T]:
kwargs = self._unpacked_params[fn]
if inspect.isasyncgenfunction(fn):
fn = cast(Callable[..., AsyncIterator[T]], fn)
async_gen = fn(*args, **kwargs)
else:
fn = cast(Callable[..., Iterator[T]], fn)
async_gen = iterate_in_threadpool(fn(*args, **kwargs))

async for item in async_gen:
yield item
return as_async_iterator(fn, *args, **kwargs)

async def __aenter__(self) -> Chat:
await self.prepare()
2 changes: 1 addition & 1 deletion ragna/deploy/_ui/api_wrapper.py
@@ -18,7 +18,7 @@ class ApiWrapper(param.Parameterized):
auth_token = param.String(default=None)

def __init__(self, api_url, **params):
self.client = httpx.AsyncClient(base_url=api_url)
self.client = httpx.AsyncClient(base_url=api_url, timeout=60)
Member Author (pmeier) commented:

    Even when streaming, the Google assistants return really large chunks and thus easily go over the default timeout. The new timeout is in line with what we use for our builtin assistants as well:

        self._client = httpx.AsyncClient(
            headers={"User-Agent": f"{ragna.__version__}/{self}"},
            timeout=60,
        )


super().__init__(**params)
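
If a flat 60-second timeout ever proves too coarse, httpx also accepts a granular httpx.Timeout. A possible variant — not what this PR does, and api_url here is assumed to be the constructor argument above:

    import httpx

    # Keep connection setup snappy but allow slow streamed reads.
    client = httpx.AsyncClient(
        base_url=api_url,
        timeout=httpx.Timeout(5.0, read=60.0),
    )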
