Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(wren-ai-service): retrieval improvement #599

Merged
merged 34 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docker/.env.ai.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
LLM_PROVIDER=openai_llm # openai_llm, azure_openai_llm, ollama_llm
GENERATION_MODEL=gpt-4o-mini
GENERATION_MODEL_KWARGS={"temperature": 0, "n": 1, "max_tokens": 4096, "response_format": {"type": "json_object"}}
COLUMN_INDEXING_BATCH_SIZE=50
TABLE_RETRIEVAL_SIZE=10
TABLE_COLUMN_RETRIEVAL_SIZE=1000

# openai or openai-api-compatible
LLM_OPENAI_API_KEY=sk-xxxx
Expand Down
5 changes: 4 additions & 1 deletion wren-ai-service/.env.dev.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
WREN_AI_SERVICE_HOST=127.0.0.1
WREN_AI_SERVICE_PORT=5556
SHOULD_FORCE_DEPLOY=
COLUMN_INDEXING_BATCH_SIZE=50
TABLE_RETRIEVAL_SIZE=10
TABLE_COLUMN_RETRIEVAL_SIZE=1000

## LLM
LLM_PROVIDER=openai_llm # openai_llm, azure_openai_llm, ollama_llm
Expand Down Expand Up @@ -64,7 +67,7 @@ WREN_ENGINE_MANIFEST=
# Evaluation
DATASET_NAME=book_2

LANGFUSE_ENABLE=false # true/false
LANGFUSE_ENABLE=
LANGFUSE_SECRET_KEY=
LANGFUSE_PUBLIC_KEY=
LANGFUSE_HOST=https://cloud.langfuse.com
Expand Down
6 changes: 0 additions & 6 deletions wren-ai-service/demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
save_mdl_json_file,
show_asks_details_results,
show_asks_results,
show_er_diagram,
update_llm,
)

Expand Down Expand Up @@ -145,11 +144,6 @@ def onchange_llm_model():
expanded=False,
)

show_er_diagram(
st.session_state["mdl_json"]["models"],
st.session_state["mdl_json"]["relationships"],
)

deploy_ok = st.button(
"Deploy",
use_container_width=True,
Expand Down
47 changes: 0 additions & 47 deletions wren-ai-service/demo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import copy
import json
import os
import re
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -179,52 +178,6 @@ def get_data_from_wren_engine(


# ui related
def show_er_diagram(models: List[dict], relationships: List[dict]):
    """Render an ER diagram of the MDL models as a Graphviz chart in Streamlit.

    Args:
        models: MDL model dicts; each must have a "name" and a list of
            "columns", where each column has "name" and "type".
        relationships: MDL relationship dicts; each must have "name",
            "models" (a 2-item [from, to] pair), "joinType", and a
            "condition" string of the form "a.x = b.y".
    """
    # Start of the Graphviz syntax
    graphviz = "digraph ERD {\n"
    graphviz += ' graph [pad="0.5", nodesep="0.5", ranksep="2"];\n'
    graphviz += " node [shape=plain]\n"
    graphviz += " rankdir=LR;\n\n"

    # Build an HTML-like Graphviz label: table header is the model name,
    # one row per column ("name : type").
    def format_label(name, columns):
        label = f'<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0"><TR><TD><B>{name}</B></TD></TR>'
        for column in columns:
            label += f'<TR><TD>{column["name"]} : {column["type"]}</TD></TR>'
        label += "</TABLE>>"
        return label

    # Add models (entities) to the Graphviz syntax
    for model in models:
        graphviz += f' {model["name"]} [label={format_label(model["name"], model["columns"])}];\n'

    graphviz += "\n"

    # Extract the two column names from a join condition like "a.x = b.y".
    # Returns ("", "") when the condition doesn't match, so the caller can
    # fall back to a bare relationship name.
    def extract_columns(condition):
        matches = re.findall(r"(\w+)\.(\w+) = (\w+)\.(\w+)", condition)
        if matches:
            return matches[0][1], matches[0][3]  # (from_column, to_column)
        return "", ""

    # Add relationships (edges) to the Graphviz syntax
    for relationship in relationships:
        from_model, to_model = relationship["models"]
        from_column, to_column = extract_columns(relationship["condition"])
        # BUGFIX: the original nested single quotes (relationship['joinType'])
        # inside a single-quoted f-string, which is a SyntaxError on
        # Python < 3.12 (quote reuse requires PEP 701). Use double quotes
        # for the inner subscript instead. "\\n" is intentional: Graphviz
        # interprets a literal backslash-n in a label as a line break.
        label = (
            f'{relationship["name"]}\\n({from_column} to {to_column}) ({relationship["joinType"]})'
            if from_column and to_column
            else relationship["name"]
        )
        graphviz += f' {from_model} -> {to_model} [label="{label}"];\n'

    graphviz += "}"

    st.markdown("ER Diagram")
    st.graphviz_chart(graphviz)


def show_query_history():
if st.session_state["query_history"]:
with st.expander("Query History", expanded=False):
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/eval/data_curation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _build_partial_mdl_json(
"content": ddl_command,
}
for new_mdl_json in new_mdl_jsons
for i, ddl_command in enumerate(ddl_converter.get_ddl_commands(new_mdl_json))
for i, ddl_command in enumerate(ddl_converter._get_ddl_commands(new_mdl_json))
]


Expand Down Expand Up @@ -274,7 +274,7 @@ async def get_question_sql_pairs(
{custom_instructions}

### Input ###
Data Model: {"\n\n".join(ddl_converter.get_ddl_commands(mdl_json))}
Data Model: {"\n\n".join(ddl_converter._get_ddl_commands(mdl_json))}

Generate {num_pairs} of the questions and corresponding SQL queries according to the Output Format in JSON
Think step by step
Expand Down
26 changes: 12 additions & 14 deletions wren-ai-service/eval/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def parse_ddl(ddl: str) -> list:

columns = []
for doc in docs:
columns.extend(parse_ddl(doc["content"]))
columns.extend(parse_ddl(doc))
return columns


Expand Down Expand Up @@ -172,6 +172,7 @@ def __init__(
self,
meta: dict,
mdl: dict,
llm_provider: LLMProvider,
embedder_provider: EmbedderProvider,
document_store_provider: DocumentStoreProvider,
**kwargs,
Expand All @@ -186,17 +187,17 @@ def __init__(
deploy_model(mdl, _indexing)

self._retrieval = retrieval.Retrieval(
llm_provider=llm_provider,
embedder_provider=embedder_provider,
document_store_provider=document_store_provider,
table_retrieval_size=meta["table_retrieval_size"],
table_column_retrieval_size=meta["table_column_retrieval_size"],
)

async def _process(self, prediction: dict, **_) -> dict:
result = await self._retrieval.run(query=prediction["input"])
documents = result.get("retrieval", {}).get("documents", [])

prediction["retrieval_context"] = extract_units(
[doc.to_dict() for doc in documents]
)
documents = result.get("construct_retrieval_results", [])
prediction["retrieval_context"] = extract_units(documents)

return prediction

Expand Down Expand Up @@ -241,17 +242,15 @@ async def _flat(self, prediction: dict, actual: str) -> dict:
return prediction

async def _process(self, prediction: dict, document: list, **_) -> dict:
documents = [Document.from_dict(doc) for doc in document]
documents = [Document.from_dict(doc).content for doc in document]
actual_output = await self._generation.run(
query=prediction["input"],
contexts=documents,
exclude=[],
)

prediction["actual_output"] = actual_output
prediction["retrieval_context"] = extract_units(
[doc.to_dict() for doc in documents]
)
prediction["retrieval_context"] = extract_units(documents)

return prediction

Expand Down Expand Up @@ -302,6 +301,7 @@ def __init__(

self._mdl = mdl
self._retrieval = retrieval.Retrieval(
llm_provider=llm_provider,
embedder_provider=embedder_provider,
document_store_provider=document_store_provider,
)
Expand All @@ -319,17 +319,15 @@ async def _flat(self, prediction: dict, actual: str) -> dict:

async def _process(self, prediction: dict, **_) -> dict:
result = await self._retrieval.run(query=prediction["input"])
documents = result.get("retrieval", {}).get("documents", [])
documents = result.get("construct_retrieval_results", [])
actual_output = await self._generation.run(
query=prediction["input"],
contexts=documents,
exclude=[],
)

prediction["actual_output"] = actual_output
prediction["retrieval_context"] = extract_units(
[doc.to_dict() for doc in documents]
)
prediction["retrieval_context"] = extract_units(documents)

return prediction

Expand Down
19 changes: 17 additions & 2 deletions wren-ai-service/eval/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def generate_meta(
"commit": obtain_commit_hash(),
"embedding_model": embedder_provider.get_model(),
"generation_model": llm_provider.get_model(),
"column_indexing_batch_size": int(os.getenv("COLUMN_INDEXING_BATCH_SIZE"))
or 50,
"table_retrieval_size": int(os.getenv("TABLE_RETRIEVAL_SIZE")) or 10,
"table_column_retrieval_size": int(os.getenv("TABLE_COLUMN_RETRIEVAL_SIZE"))
or 1000,
"pipeline": pipe,
"batch_size": os.getenv("BATCH_SIZE") or 4,
"batch_interval": os.getenv("BATCH_INTERVAL") or 1,
Expand Down Expand Up @@ -138,9 +143,19 @@ def parse_args() -> Tuple[str]:
dataset = parse_toml(path)
providers = init_providers(dataset["mdl"])

meta = generate_meta(path=path, dataset=dataset, pipe=pipe_name, **providers)
meta = generate_meta(
path=path,
dataset=dataset,
pipe=pipe_name,
**providers,
)

pipe = pipelines.init(pipe_name, meta, mdl=dataset["mdl"], providers=providers)
pipe = pipelines.init(
pipe_name,
meta,
mdl=dataset["mdl"],
providers=providers,
)

predictions = pipe.predict(dataset["eval_dataset"])
meta["expected_batch_size"] = meta["query_count"] * pipe.candidate_size
Expand Down
3 changes: 3 additions & 0 deletions wren-ai-service/eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ def trace_metadata(
"dataset_id": meta["dataset_id"],
"embedding_model": meta["embedding_model"],
"generation_model": meta["generation_model"],
"column_indexing_batch_size": meta["column_indexing_batch_size"],
"table_retrieval_size": meta["table_retrieval_size"],
"table_column_retrieval_size": meta["table_column_retrieval_size"],
"type": type,
"pipeline": meta["pipeline"],
}
Expand Down
Loading
Loading