Skip to content

Commit

Permalink
chore(wren-ai-service): retrieval improvement (#599)
Browse files Browse the repository at this point in the history
* refactor

* add table description to indexing

* remove print

* rename

* fix conflict

* refine indexing to have columns per batch

* update

* fix

* fix bug

* fix bug

* fix

* fix

* update

* update

* fix bug

* update top k

* fix eval

* fix bug

* update

* update

* update

* skip if 0 documents

* fix top k

* delete documents by filters instead of ids

* fix bug

* update

* update

* update

* fix test

* add COLUMN_INDEXING_BATCH_SIZE

* update

* fix bug

* update

* fix bug
  • Loading branch information
cyyeh authored Aug 30, 2024
1 parent 19e47f7 commit 4e0acc9
Show file tree
Hide file tree
Showing 20 changed files with 815 additions and 332 deletions.
3 changes: 3 additions & 0 deletions docker/.env.ai.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
LLM_PROVIDER=openai_llm # openai_llm, azure_openai_llm, ollama_llm
GENERATION_MODEL=gpt-4o-mini
GENERATION_MODEL_KWARGS={"temperature": 0, "n": 1, "max_tokens": 4096, "response_format": {"type": "json_object"}}
COLUMN_INDEXING_BATCH_SIZE=50
TABLE_RETRIEVAL_SIZE=10
TABLE_COLUMN_RETRIEVAL_SIZE=1000

# openai or openai-api-compatible
LLM_OPENAI_API_KEY=sk-xxxx
Expand Down
5 changes: 4 additions & 1 deletion wren-ai-service/.env.dev.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
WREN_AI_SERVICE_HOST=127.0.0.1
WREN_AI_SERVICE_PORT=5556
SHOULD_FORCE_DEPLOY=
COLUMN_INDEXING_BATCH_SIZE=50
TABLE_RETRIEVAL_SIZE=10
TABLE_COLUMN_RETRIEVAL_SIZE=1000

## LLM
LLM_PROVIDER=openai_llm # openai_llm, azure_openai_llm, ollama_llm
Expand Down Expand Up @@ -64,7 +67,7 @@ WREN_ENGINE_MANIFEST=
# Evaluation
DATASET_NAME=book_2

LANGFUSE_ENABLE=false # true/false
LANGFUSE_ENABLE=
LANGFUSE_SECRET_KEY=
LANGFUSE_PUBLIC_KEY=
LANGFUSE_HOST=https://cloud.langfuse.com
Expand Down
6 changes: 0 additions & 6 deletions wren-ai-service/demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
save_mdl_json_file,
show_asks_details_results,
show_asks_results,
show_er_diagram,
update_llm,
)

Expand Down Expand Up @@ -145,11 +144,6 @@ def onchange_llm_model():
expanded=False,
)

show_er_diagram(
st.session_state["mdl_json"]["models"],
st.session_state["mdl_json"]["relationships"],
)

deploy_ok = st.button(
"Deploy",
use_container_width=True,
Expand Down
47 changes: 0 additions & 47 deletions wren-ai-service/demo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import copy
import json
import os
import re
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -179,52 +178,6 @@ def get_data_from_wren_engine(


# ui related
def show_er_diagram(models: List[dict], relationships: List[dict]):
    """Render an entity-relationship diagram of the MDL in Streamlit.

    Builds a Graphviz ``digraph`` description from the given models and
    relationships and displays it via ``st.graphviz_chart``.

    Args:
        models: MDL model dicts; each must have ``name`` and ``columns``
            (columns each carrying ``name`` and ``type``).
        relationships: MDL relationship dicts; each must have ``name``,
            ``models`` (a 2-item from/to pair), ``condition`` (a join
            condition like ``a.x = b.y``), and ``joinType``.
    """
    # Start of the Graphviz syntax
    graphviz = "digraph ERD {\n"
    graphviz += ' graph [pad="0.5", nodesep="0.5", ranksep="2"];\n'
    graphviz += " node [shape=plain]\n"
    graphviz += " rankdir=LR;\n\n"

    # Function to format the label for Graphviz (HTML-like table label)
    def format_label(name, columns):
        label = f'<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0"><TR><TD><B>{name}</B></TD></TR>'
        for column in columns:
            label += f'<TR><TD>{column["name"]} : {column["type"]}</TD></TR>'
        label += "</TABLE>>"
        return label

    # Add models (entities) to the Graphviz syntax
    for model in models:
        graphviz += f' {model["name"]} [label={format_label(model["name"], model["columns"])}];\n'

    graphviz += "\n"

    # Extract columns involved in each relationship
    def extract_columns(condition):
        # This regular expression should match the condition format and extract column names
        matches = re.findall(r"(\w+)\.(\w+) = (\w+)\.(\w+)", condition)
        if matches:
            return matches[0][1], matches[0][3]  # Returns (from_column, to_column)
        return "", ""

    # Add relationships to the Graphviz syntax
    for relationship in relationships:
        from_model, to_model = relationship["models"]
        from_column, to_column = extract_columns(relationship["condition"])
        # BUG FIX: the subscript below originally used single quotes inside a
        # single-quoted f-string (relationship['joinType']), which is a
        # SyntaxError on Python < 3.12 (nested same-quote f-strings are only
        # allowed since PEP 701). Use double quotes like the other subscripts.
        label = (
            f'{relationship["name"]}\\n({from_column} to {to_column}) ({relationship["joinType"]})'
            if from_column and to_column
            else relationship["name"]
        )
        graphviz += f' {from_model} -> {to_model} [label="{label}"];\n'

    graphviz += "}"

    st.markdown("ER Diagram")
    st.graphviz_chart(graphviz)


def show_query_history():
if st.session_state["query_history"]:
with st.expander("Query History", expanded=False):
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/eval/data_curation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _build_partial_mdl_json(
"content": ddl_command,
}
for new_mdl_json in new_mdl_jsons
for i, ddl_command in enumerate(ddl_converter.get_ddl_commands(new_mdl_json))
for i, ddl_command in enumerate(ddl_converter._get_ddl_commands(new_mdl_json))
]


Expand Down Expand Up @@ -274,7 +274,7 @@ async def get_question_sql_pairs(
{custom_instructions}
### Input ###
Data Model: {"\n\n".join(ddl_converter.get_ddl_commands(mdl_json))}
Data Model: {"\n\n".join(ddl_converter._get_ddl_commands(mdl_json))}
Generate {num_pairs} of the questions and corresponding SQL queries according to the Output Format in JSON
Think step by step
Expand Down
26 changes: 12 additions & 14 deletions wren-ai-service/eval/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def parse_ddl(ddl: str) -> list:

columns = []
for doc in docs:
columns.extend(parse_ddl(doc["content"]))
columns.extend(parse_ddl(doc))
return columns


Expand Down Expand Up @@ -172,6 +172,7 @@ def __init__(
self,
meta: dict,
mdl: dict,
llm_provider: LLMProvider,
embedder_provider: EmbedderProvider,
document_store_provider: DocumentStoreProvider,
**kwargs,
Expand All @@ -186,17 +187,17 @@ def __init__(
deploy_model(mdl, _indexing)

self._retrieval = retrieval.Retrieval(
llm_provider=llm_provider,
embedder_provider=embedder_provider,
document_store_provider=document_store_provider,
table_retrieval_size=meta["table_retrieval_size"],
table_column_retrieval_size=meta["table_column_retrieval_size"],
)

async def _process(self, prediction: dict, **_) -> dict:
result = await self._retrieval.run(query=prediction["input"])
documents = result.get("retrieval", {}).get("documents", [])

prediction["retrieval_context"] = extract_units(
[doc.to_dict() for doc in documents]
)
documents = result.get("construct_retrieval_results", [])
prediction["retrieval_context"] = extract_units(documents)

return prediction

Expand Down Expand Up @@ -241,17 +242,15 @@ async def _flat(self, prediction: dict, actual: str) -> dict:
return prediction

async def _process(self, prediction: dict, document: list, **_) -> dict:
documents = [Document.from_dict(doc) for doc in document]
documents = [Document.from_dict(doc).content for doc in document]
actual_output = await self._generation.run(
query=prediction["input"],
contexts=documents,
exclude=[],
)

prediction["actual_output"] = actual_output
prediction["retrieval_context"] = extract_units(
[doc.to_dict() for doc in documents]
)
prediction["retrieval_context"] = extract_units(documents)

return prediction

Expand Down Expand Up @@ -302,6 +301,7 @@ def __init__(

self._mdl = mdl
self._retrieval = retrieval.Retrieval(
llm_provider=llm_provider,
embedder_provider=embedder_provider,
document_store_provider=document_store_provider,
)
Expand All @@ -319,17 +319,15 @@ async def _flat(self, prediction: dict, actual: str) -> dict:

async def _process(self, prediction: dict, **_) -> dict:
result = await self._retrieval.run(query=prediction["input"])
documents = result.get("retrieval", {}).get("documents", [])
documents = result.get("construct_retrieval_results", [])
actual_output = await self._generation.run(
query=prediction["input"],
contexts=documents,
exclude=[],
)

prediction["actual_output"] = actual_output
prediction["retrieval_context"] = extract_units(
[doc.to_dict() for doc in documents]
)
prediction["retrieval_context"] = extract_units(documents)

return prediction

Expand Down
19 changes: 17 additions & 2 deletions wren-ai-service/eval/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def generate_meta(
"commit": obtain_commit_hash(),
"embedding_model": embedder_provider.get_model(),
"generation_model": llm_provider.get_model(),
"column_indexing_batch_size": int(os.getenv("COLUMN_INDEXING_BATCH_SIZE"))
or 50,
"table_retrieval_size": int(os.getenv("TABLE_RETRIEVAL_SIZE")) or 10,
"table_column_retrieval_size": int(os.getenv("TABLE_COLUMN_RETRIEVAL_SIZE"))
or 1000,
"pipeline": pipe,
"batch_size": os.getenv("BATCH_SIZE") or 4,
"batch_interval": os.getenv("BATCH_INTERVAL") or 1,
Expand Down Expand Up @@ -138,9 +143,19 @@ def parse_args() -> Tuple[str]:
dataset = parse_toml(path)
providers = init_providers(dataset["mdl"])

meta = generate_meta(path=path, dataset=dataset, pipe=pipe_name, **providers)
meta = generate_meta(
path=path,
dataset=dataset,
pipe=pipe_name,
**providers,
)

pipe = pipelines.init(pipe_name, meta, mdl=dataset["mdl"], providers=providers)
pipe = pipelines.init(
pipe_name,
meta,
mdl=dataset["mdl"],
providers=providers,
)

predictions = pipe.predict(dataset["eval_dataset"])
meta["expected_batch_size"] = meta["query_count"] * pipe.candidate_size
Expand Down
3 changes: 3 additions & 0 deletions wren-ai-service/eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ def trace_metadata(
"dataset_id": meta["dataset_id"],
"embedding_model": meta["embedding_model"],
"generation_model": meta["generation_model"],
"column_indexing_batch_size": meta["column_indexing_batch_size"],
"table_retrieval_size": meta["table_retrieval_size"],
"table_column_retrieval_size": meta["table_column_retrieval_size"],
"type": type,
"pipeline": meta["pipeline"],
}
Expand Down
Loading

0 comments on commit 4e0acc9

Please sign in to comment.