Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions wren-ai-service/src/pipelines/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Any, List, Optional, Tuple

from haystack import Document, component
Expand Down Expand Up @@ -102,3 +103,10 @@ def run(
reverse=True,
)[:max_size]
}


# Matches any run of three or more consecutive newline characters.
MULTIPLE_NEW_LINE_REGEX = re.compile(r"\n{3,}")


def clean_up_new_lines(text: str) -> str:
    """Normalize excessive vertical whitespace in *text*.

    Any run of three or more consecutive newlines is collapsed down to
    exactly three (i.e. at most two blank lines in a row), which keeps
    rendered prompts compact without merging intentionally separated
    sections.
    """
    collapsed = MULTIPLE_NEW_LINE_REGEX.sub("\n\n\n", text)
    return collapsed
4 changes: 3 additions & 1 deletion wren-ai-service/src/pipelines/generation/chart_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.common import clean_up_new_lines
from src.pipelines.generation.utils.chart import (
ChartDataPreprocessor,
ChartGenerationPostProcessor,
Expand Down Expand Up @@ -97,7 +98,7 @@ def prompt(
sample_data = preprocess_data.get("sample_data")
sample_column_values = preprocess_data.get("sample_column_values")

return prompt_builder.run(
_prompt = prompt_builder.run(
query=query,
sql=sql,
adjustment_option=adjustment_option,
Expand All @@ -106,6 +107,7 @@ def prompt(
sample_column_values=sample_column_values,
language=language,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
4 changes: 3 additions & 1 deletion wren-ai-service/src/pipelines/generation/chart_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.common import clean_up_new_lines
from src.pipelines.generation.utils.chart import (
ChartDataPreprocessor,
ChartGenerationPostProcessor,
Expand Down Expand Up @@ -73,14 +74,15 @@ def prompt(
sample_data = preprocess_data.get("sample_data")
sample_column_values = preprocess_data.get("sample_column_values")

return prompt_builder.run(
_prompt = prompt_builder.run(
query=query,
sql=sql,
sample_data=sample_data,
sample_column_values=sample_column_values,
language=language,
custom_instruction=custom_instruction,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
4 changes: 3 additions & 1 deletion wren-ai-service/src/pipelines/generation/data_assistance.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.common import clean_up_new_lines
from src.utils import trace_cost
from src.web.v1.services.ask import AskHistory

Expand Down Expand Up @@ -61,11 +62,12 @@ def prompt(
)
query = "\n".join(previous_query_summaries) + "\n" + query

return prompt_builder.run(
_prompt = prompt_builder.run(
query=query,
db_schemas=db_schemas,
language=language,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from src.core.engine import Engine
from src.core.pipeline import BasicPipeline
from src.core.provider import DocumentStoreProvider, LLMProvider
from src.pipelines.common import retrieve_metadata
from src.pipelines.common import clean_up_new_lines, retrieve_metadata
from src.pipelines.generation.utils.sql import (
SQL_GENERATION_MODEL_KWARGS,
SQLGenPostProcessor,
Expand Down Expand Up @@ -98,7 +98,7 @@ def prompt(
has_json_field: bool = False,
sql_functions: list[SqlFunction] | None = None,
) -> dict:
return prompt_builder.run(
_prompt = prompt_builder.run(
query=query,
documents=documents,
sql_generation_reasoning=sql_generation_reasoning,
Expand All @@ -113,6 +113,7 @@ def prompt(
sql_samples=sql_samples,
sql_functions=sql_functions,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.common import clean_up_new_lines
from src.pipelines.generation.utils.sql import (
construct_instructions,
sql_generation_reasoning_system_prompt,
Expand Down Expand Up @@ -71,7 +72,7 @@ def prompt(
prompt_builder: PromptBuilder,
configuration: Configuration | None = Configuration(),
) -> dict:
return prompt_builder.run(
_prompt = prompt_builder.run(
query=query,
documents=documents,
histories=histories,
Expand All @@ -81,6 +82,7 @@ def prompt(
),
language=configuration.language,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from src.core.pipeline import BasicPipeline
from src.core.provider import DocumentStoreProvider, EmbedderProvider, LLMProvider
from src.pipelines.common import build_table_ddl
from src.pipelines.common import build_table_ddl, clean_up_new_lines
from src.pipelines.generation.utils.sql import construct_instructions
from src.utils import trace_cost
from src.web.v1.services import Configuration
Expand Down Expand Up @@ -276,7 +276,7 @@ def prompt(
instructions: Optional[list[dict]] = None,
configuration: Configuration | None = None,
) -> dict:
return prompt_builder.run(
_prompt = prompt_builder.run(
query=query,
language=configuration.language,
db_schemas=construct_db_schemas,
Expand All @@ -287,6 +287,7 @@ def prompt(
),
docs=wren_ai_docs,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.common import clean_up_new_lines
from src.utils import trace_cost
from src.web.v1.services.ask import AskHistory

Expand Down Expand Up @@ -61,11 +62,12 @@ def prompt(
)
query = "\n".join(previous_query_summaries) + "\n" + query

return prompt_builder.run(
_prompt = prompt_builder.run(
query=query,
db_schemas=db_schemas,
language=language,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.common import clean_up_new_lines
from src.utils import trace_cost

logger = logging.getLogger("wren-ai-service")
Expand All @@ -32,13 +33,14 @@ def prompt(
contextually relevant questions that build on previous questions.
"""

return prompt_builder.run(
_prompt = prompt_builder.run(
models=[] if previous_questions else mdl.get("models", []),
previous_questions=previous_questions,
language=language,
max_questions=max_questions,
max_categories=max_categories,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from src.core.engine import Engine
from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.common import clean_up_new_lines
from src.utils import trace_cost

logger = logging.getLogger("wren-ai-service")
Expand Down Expand Up @@ -51,7 +52,8 @@ def prompt(
prompt_builder: PromptBuilder,
language: str,
) -> dict:
return prompt_builder.run(models=cleaned_models, language=language)
_prompt = prompt_builder.run(models=cleaned_models, language=language)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.common import clean_up_new_lines
from src.utils import trace_cost

logger = logging.getLogger("wren-ai-service")
Expand Down Expand Up @@ -60,11 +61,12 @@ def prompt(
prompt_builder: PromptBuilder,
language: str,
) -> dict:
return prompt_builder.run(
_prompt = prompt_builder.run(
picked_models=picked_models,
user_prompt=user_prompt,
language=language,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
15 changes: 13 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.common import clean_up_new_lines
from src.utils import trace_cost
from src.web.v1.services import Configuration

logger = logging.getLogger("wren-ai-service")

Expand Down Expand Up @@ -40,8 +42,12 @@
### Input
User's question: {{ query }}
SQL: {{ sql }}
Data: {{ sql_data }}
Data:
columns: {{ sql_data.columns }}
rows: {{ sql_data.data }}
Language: {{ language }}
Current Time: {{ current_time }}

Custom Instruction: {{ custom_instruction }}

Please think step by step and answer the user's question.
Expand All @@ -55,16 +61,19 @@ def prompt(
sql: str,
sql_data: dict,
language: str,
current_time: str,
custom_instruction: str,
prompt_builder: PromptBuilder,
) -> dict:
return prompt_builder.run(
_prompt = prompt_builder.run(
query=query,
sql=sql,
sql_data=sql_data,
language=language,
current_time=current_time,
custom_instruction=custom_instruction,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down Expand Up @@ -144,6 +153,7 @@ async def run(
sql: str,
sql_data: dict,
language: str,
current_time: str = Configuration().show_current_time(),
query_id: Optional[str] = None,
custom_instruction: Optional[str] = None,
) -> dict:
Expand All @@ -155,6 +165,7 @@ async def run(
"sql": sql,
"sql_data": sql_data,
"language": language,
"current_time": current_time,
"query_id": query_id,
"custom_instruction": custom_instruction or "",
**self._components,
Expand Down
5 changes: 3 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from src.core.engine import Engine
from src.core.pipeline import BasicPipeline
from src.core.provider import DocumentStoreProvider, LLMProvider
from src.pipelines.common import retrieve_metadata
from src.pipelines.common import clean_up_new_lines, retrieve_metadata
from src.pipelines.generation.utils.sql import (
SQL_GENERATION_MODEL_KWARGS,
TEXT_TO_SQL_RULES,
Expand Down Expand Up @@ -66,10 +66,11 @@ def prompt(
invalid_generation_result: Dict,
prompt_builder: PromptBuilder,
) -> dict:
return prompt_builder.run(
_prompt = prompt_builder.run(
documents=documents,
invalid_generation_result=invalid_generation_result,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
5 changes: 3 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from src.core.engine import Engine
from src.core.pipeline import BasicPipeline
from src.core.provider import DocumentStoreProvider, LLMProvider
from src.pipelines.common import retrieve_metadata
from src.pipelines.common import clean_up_new_lines, retrieve_metadata
from src.pipelines.generation.utils.sql import (
SQL_GENERATION_MODEL_KWARGS,
SQLGenPostProcessor,
Expand Down Expand Up @@ -94,7 +94,7 @@ def prompt(
has_json_field: bool = False,
sql_functions: list[SqlFunction] | None = None,
) -> dict:
return prompt_builder.run(
_prompt = prompt_builder.run(
query=query,
documents=documents,
sql_generation_reasoning=sql_generation_reasoning,
Expand All @@ -109,6 +109,7 @@ def prompt(
sql_samples=sql_samples,
sql_functions=sql_functions,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.common import clean_up_new_lines
from src.pipelines.generation.utils.sql import (
construct_instructions,
sql_generation_reasoning_system_prompt,
Expand Down Expand Up @@ -61,7 +62,7 @@ def prompt(
prompt_builder: PromptBuilder,
configuration: Configuration | None = Configuration(),
) -> dict:
return prompt_builder.run(
_prompt = prompt_builder.run(
query=query,
documents=documents,
sql_samples=sql_samples,
Expand All @@ -70,6 +71,7 @@ def prompt(
),
language=configuration.language,
)
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}


@observe(as_type="generation", capture_input=False)
Expand Down
Loading
Loading