diff --git a/wren-ai-service/Justfile b/wren-ai-service/Justfile index a5c654044f..7b1335712b 100644 --- a/wren-ai-service/Justfile +++ b/wren-ai-service/Justfile @@ -47,8 +47,8 @@ demo: test test_args='': up && down poetry run pytest -s {{test_args}} --ignore tests/pytest/test_usecases.py -test-usecases usecases='all': - poetry run python -m tests.pytest.test_usecases --usecases {{usecases}} +test-usecases usecases='all' lang='en': + poetry run python -m tests.pytest.test_usecases --usecases {{usecases}} --lang {{lang}} load-test: poetry run python -m tests.locust.locust_script diff --git a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py index 5780bb1e17..2586b767cf 100644 --- a/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/historical_question_retrieval.py @@ -22,16 +22,16 @@ class OutputFormatter: documents=List[Optional[Dict]], ) def run(self, documents: List[Document]): - list = [] - - for doc in documents: - formatted = { + list = [ + { "question": doc.content, "summary": doc.meta.get("summary", ""), "statement": doc.meta.get("statement") or doc.meta.get("sql"), "viewId": doc.meta.get("viewId", ""), + "sqlpairId": doc.meta.get("sql_pair_id", ""), } - list.append(formatted) + for doc in documents + ] return {"documents": list} diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index f0a5e655d9..8d339dac2e 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -65,8 +65,9 @@ class StopAskResponse(BaseModel): # GET /v1/asks/{query_id}/result class AskResult(BaseModel): sql: str - type: Literal["llm", "view"] = "llm" + type: Literal["llm", "view", "sql_pair"] = "llm" viewId: Optional[str] = None + sqlpairId: Optional[str] = None class AskError(BaseModel): @@ -243,8 +244,13 @@ async def ask( AskResult( **{ "sql": result.get("statement"), - "type": "view", + "type": "view" + if result.get("viewId") + else "sql_pair" + if result.get("sqlpairId") + else "llm", "viewId": result.get("viewId"), + "sqlpairId": result.get("sqlpairId"), } ) for result in historical_question_result diff --git a/wren-ai-service/tests/pytest/test_usecases.py b/wren-ai-service/tests/pytest/test_usecases.py index 03a9e78f69..c48633e21f 100644 --- a/wren-ai-service/tests/pytest/test_usecases.py +++ b/wren-ai-service/tests/pytest/test_usecases.py @@ -107,12 +107,19 @@ def deploy_mdl(mdl_str: str, url: str): return semantics_preperation_id -async def ask_question(question: str, url: str, semantics_preperation_id: str): +async def ask_question( + question: str, url: str, semantics_preperation_id: str, lang: str = "English" +): print(f"preparing to ask question: {question}") async with aiohttp.ClientSession() as session: start = time.time() response = await session.post( - f"{url}/v1/asks", json={"query": question, "id": semantics_preperation_id} + f"{url}/v1/asks", + json={ + "query": question, + "id": semantics_preperation_id, + "configurations": {"language": lang}, + }, ) assert response.status == 200 @@ -133,11 +140,13 @@ async def ask_question(question: str, url: str, semantics_preperation_id: str): return result -async def ask_questions(questions: list[str], url: str, semantics_preperation_id: str): +async def ask_questions( + questions: list[str], url: str, semantics_preperation_id: str, lang: str = "English" +): tasks = [] for question in questions: task = asyncio.ensure_future( - ask_question(question, url, semantics_preperation_id) + ask_question(question, url, semantics_preperation_id, lang) ) tasks.append(task) await asyncio.sleep(10) @@ -160,7 +169,7 @@ def str_presenter(dumper, data): "woocommerce": "bigquery", "stripe": "bigquery", "ecommerce": "duckdb", - "hr": "duckdb", + # "hr": "duckdb", "facebook_marketing": "bigquery", "google_ads": "bigquery", } @@ -174,11 +183,23 @@ def str_presenter(dumper, data): default=["all"], choices=["all"] + usecases, ) + parser.add_argument( + "--lang", + type=str, + choices=["en", "zh-TW", "zh-CN"], + default="en", + ) args = parser.parse_args() if "all" not in args.usecases: usecases = args.usecases + lang = { + "en": "English", + "zh-TW": "Traditional Chinese", + "zh-CN": "Simplified Chinese", + }[args.lang] + url = "http://localhost:5556" assert is_ai_service_ready( @@ -197,7 +218,7 @@ def str_presenter(dumper, data): # ask questions results = asyncio.run( - ask_questions(data["questions"], url, semantics_preperation_id) + ask_questions(data["questions"], url, semantics_preperation_id, lang) ) assert len(results) == len(data["questions"])