
Commit

feat: add rag client (#6844)
Co-authored-by: Sergey Kulik <[email protected]>
GitOrigin-RevId: 0b844467557e25a96c83c13a5af01eb294242112
2 people authored and Manul from Pathway committed Jul 5, 2024
1 parent 6d8d90a commit ebd2c09
Showing 2 changed files with 168 additions and 2 deletions.
168 changes: 167 additions & 1 deletion python/pathway/xpacks/llm/question_answering.py
@@ -2,13 +2,15 @@
import json
from enum import Enum

import requests

import pathway as pw
from pathway.internals import ColumnReference, Table
from pathway.stdlib.indexing import DataIndex
from pathway.xpacks.llm import Doc, llms, prompts
from pathway.xpacks.llm.llms import prompt_chat_single_qa
from pathway.xpacks.llm.prompts import prompt_qa_geometric_rag
from pathway.xpacks.llm.vector_store import VectorStoreServer
from pathway.xpacks.llm.vector_store import VectorStoreClient, VectorStoreServer


@pw.udf
@@ -612,3 +614,167 @@ def pw_ai_query(self, pw_ai_queries: pw.Table) -> pw.Table:
)

return result


def send_post_request(
url: str, data: dict, headers: dict | None = None, timeout: int | None = None
):
response = requests.post(url, json=data, headers=headers, timeout=timeout)
response.raise_for_status()
return response.json()


class RAGClient:
"""
Connector for interacting with Pathway RAG applications.
Either (`host` and `port`) or `url` must be set.
Args:
- host: The host of the RAG service.
- port: The port of the RAG service.
- url: The URL of the RAG service.
- timeout: Timeout for requests in seconds. Defaults to 90.
- additional_headers: Additional headers for the requests.
"""

def __init__(
self,
host: str | None = None,
port: int | None = None,
url: str | None = None,
timeout: int | None = 90,
additional_headers: dict | None = None,
):
err = "Either (`host` and `port`) or `url` must be provided, but not both."
if url is not None:
if host is not None or port is not None:
raise ValueError(err)
self.url = url
else:
if host is None:
raise ValueError(err)
port = port or 80

protocol = "https" if port == 443 else "http"
self.url = f"{protocol}://{host}:{port}"

self.timeout = timeout
self.additional_headers = additional_headers or {}

self.index_client = VectorStoreClient(
url=self.url,
timeout=self.timeout,
additional_headers=self.additional_headers,
)

def retrieve(
self,
query: str,
k: int = 3,
metadata_filter: str | None = None,
filepath_globpattern: str | None = None,
):
"""
Retrieve the closest documents from the vector store based on a query.
Args:
- query: The query string.
- k: The number of results to retrieve.
- metadata_filter: Optional metadata filter for the documents. Defaults to ``None``, which
means there will be no filter.
- filepath_globpattern: Glob pattern for file paths.
"""
return self.index_client.query(
query=query,
k=k,
metadata_filter=metadata_filter,
filepath_globpattern=filepath_globpattern,
)

def statistics(
self,
):
"""
Retrieve stats from the vector store.
"""
return self.index_client.get_vectorstore_statistics()

def pw_ai_answer(
self,
prompt: str,
filters: str | None = None,
model: str | None = None,
):
"""
Return a RAG answer based on a given prompt and an optional filter.
Args:
- prompt: Question to be asked.
- filters: Optional metadata filter for the documents. Defaults to ``None``, which
means there will be no filter.
- model: Optional LLM model. If ``None``, app default will be used by the server.
"""
api_url = f"{self.url}/v1/pw_ai_answer"
payload = {
"prompt": prompt,
}

if filters:
payload["filters"] = filters

if model:
payload["model"] = model

response = send_post_request(api_url, payload, self.additional_headers)
return response

def pw_ai_summary(
self,
text_list: list[str],
model: str | None = None,
):
"""
Summarize a list of texts.
Args:
- text_list: List of texts to summarize.
- model: Optional LLM model. If ``None``, app default will be used by the server.
"""
api_url = f"{self.url}/v1/pw_ai_summary"
payload: dict = {
"text_list": text_list,
}

if model:
payload["model"] = model

response = send_post_request(api_url, payload, self.additional_headers)
return response

def pw_list_documents(self, filters: str | None = None, keys: list[str] | None = ["path"]):
"""
List indexed documents from the vector store with optional filtering.
Args:
- filters: Optional metadata filter for the documents.
- keys: List of metadata keys to be included in the response.
Defaults to ``["path"]``. Setting to ``None`` will retrieve all available metadata.
"""
api_url = f"{self.url}/v1/pw_list_documents"
payload = {}

if filters:
payload["metadata_filter"] = filters

response: list[dict] = send_post_request(
api_url, payload, self.additional_headers
)

if response:
if keys:
result = [{k: v for k, v in dc.items() if k in keys} for dc in response]
else:
result = response
else:
result = []
return result
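
A minimal usage sketch of the client added above, assuming a Pathway RAG app exposing these endpoints is already running; the address, query strings, and k value are illustrative, not part of the commit:

from pathway.xpacks.llm.question_answering import RAGClient

# Hypothetical server address; the client builds "http://localhost:8000" from host and port.
client = RAGClient(host="localhost", port=8000)

# Closest documents for a query, served through the underlying VectorStoreClient.
docs = client.retrieve("What is Pathway?", k=2)

# Statistics reported by the vector store server.
stats = client.statistics()

# Answer from the /v1/pw_ai_answer endpoint; with no model given, the app default is used.
answer = client.pw_ai_answer("What is Pathway?")

# Summary from /v1/pw_ai_summary (the sample texts are illustrative).
summary = client.pw_ai_summary(["Pathway processes live data.", "It ships LLM xpacks."])

# Metadata of indexed documents; by default only the "path" key is kept.
paths = client.pw_list_documents()
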
2 changes: 1 addition & 1 deletion python/pathway/xpacks/llm/vector_store.py
@@ -625,7 +625,7 @@ def __init__(
host: str | None = None,
port: int | None = None,
url: str | None = None,
timeout: int = 15,
timeout: int | None = 15,
additional_headers: dict | None = None,
):
err = "Either (`host` and `port`) or `url` must be provided, but not both."
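
The widened timeout annotation on VectorStoreClient mirrors the new RAGClient signature, which forwards its timeout to the VectorStoreClient it creates. A brief sketch, assuming a timeout of None is passed through to the underlying requests calls and therefore disables the client-side deadline (the URL is hypothetical):

# Hypothetical address; timeout=None propagates to VectorStoreClient, so requests
# will wait for slow endpoints without raising a client-side timeout.
client = RAGClient(url="http://localhost:8000", timeout=None)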
