Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions wren/src/wren/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,34 @@ def _build_engine(
# Shared Typer option aliases; both the default command and ``query`` use
# them as parameter annotations.
# --output / -o: string naming the result format (json, csv or table).
OutputOpt = Annotated[
    str, typer.Option("--output", "-o", help="Output format: json|csv|table")
]
# --quiet / -q: boolean flag that silences informational tips such as the
# store hint printed after a query.
QuietOpt = Annotated[
    bool,
    typer.Option(
        "--quiet",
        "-q",
        help="Suppress informational tips (e.g. store hints after query).",
    ),
]


def _print_store_tip(sql: str) -> None:
    """Write a ``wren memory store`` hint for *sql* to stderr.

    The suggested command wraps the SQL in single quotes, so every single
    quote inside the SQL is replaced by the POSIX-shell sequence that closes
    the quoted string, emits an escaped quote, and reopens the string.
    """
    # Splitting on the quote and re-joining with the escape sequence is
    # equivalent to a global replace of each single quote.
    shell_safe_sql = "'\\''".join(sql.split("'"))
    hint = (
        "\n# To save this query:\n"
        "# wren memory store --nl '<natural language question>' "
        f"--sql '{shell_safe_sql}'"
    )
    # err=True keeps the hint off stdout so piped query output stays clean.
    typer.echo(hint, err=True)


def _maybe_print_store_tip(sql: str, quiet: bool) -> None:
    """Print the store hint for *sql* unless it is silenced or not worthwhile.

    Nothing is printed when *quiet* is true, or when the classifier judges
    the query to be an exploratory peek.
    """
    if not quiet:
        # Imported lazily, and only when a tip might actually be shown.
        from wren.sql_classify import is_exploratory  # noqa: PLC0415

        worth_saving = not is_exploratory(sql)
        if worth_saving:
            _print_store_tip(sql)


# ── Default command (no subcommand = query) ────────────────────────────────
Expand All @@ -201,6 +229,7 @@ def main(
connection_file: ConnFileOpt = None,
limit: LimitOpt = None,
output: OutputOpt = "table",
quiet: QuietOpt = False,
) -> None:
"""Wren Engine CLI.

Expand Down Expand Up @@ -243,6 +272,7 @@ def main(
typer.echo(f"Error: {e}", err=True)
raise typer.Exit(1)
_print_result(result, output)
_maybe_print_store_tip(sql, quiet)


# ── Subcommands ────────────────────────────────────────────────────────────
Expand All @@ -256,6 +286,7 @@ def query(
connection_file: ConnFileOpt = None,
limit: LimitOpt = None,
output: OutputOpt = "table",
quiet: QuietOpt = False,
):
"""Execute a SQL query through the Wren semantic layer."""
with _build_engine(mdl, connection_info, connection_file) as engine:
Expand All @@ -265,6 +296,7 @@ def query(
typer.echo(f"Error: {e}", err=True)
raise typer.Exit(1)
_print_result(result, output)
_maybe_print_store_tip(sql, quiet)


@app.command(name="dry-run")
Expand Down
41 changes: 41 additions & 0 deletions wren/src/wren/sql_classify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Lightweight SQL classification for store-tip heuristics."""

import sqlglot
from sqlglot import exp


def is_exploratory(sql: str) -> bool:
    """Return True if *sql* looks like an exploratory peek query.

    Exploratory means a single, bare SELECT statement: no CTE, no
    WHERE / GROUP BY / HAVING at the top level, and no aggregate function
    anywhere in the statement. A LIMIT is neither required nor disqualifying.

    Top-level clauses are read from ``stmt.args`` so that clauses belonging
    to a subquery (for example an inner LIMIT) do not affect the verdict for
    the outer query.

    Returns False for anything that cannot be classified: SQL that fails to
    tokenize or parse, empty input, multiple statements, or a non-SELECT
    statement. Callers use False to mean "do not suppress the tip".
    """
    try:
        statements = sqlglot.parse(sql)
    except sqlglot.errors.SqlglotError:
        # Catch the common base class rather than ParseError alone: tokenizer
        # failures (e.g. an unterminated string literal) raise TokenError,
        # which is a sibling of ParseError, and this best-effort heuristic
        # must never crash the command that calls it.
        return False

    if len(statements) != 1:
        return False  # multi-statement input is not a simple peek

    stmt = statements[0]
    if not isinstance(stmt, exp.Select):
        return False  # non-SELECT (including UNION or an empty parse result)

    top_level = stmt.args

    # A CTE-backed SELECT stores its With clause under the "with_" key.
    # Older sqlglot releases appear to use "with" instead, so both keys are
    # checked; a missing key simply yields None.
    if top_level.get("with_") is not None or top_level.get("with") is not None:
        return False

    # Only the statement's own clauses are inspected here; this lookup does
    # not descend into subqueries.
    if any(
        top_level.get(clause) is not None
        for clause in ("where", "group", "having")
    ):
        return False  # analytical query

    # find() walks the whole tree, which is intended for aggregates: an
    # aggregate anywhere indicates non-trivial processing.
    return stmt.find(exp.AggFunc) is None
130 changes: 130 additions & 0 deletions wren/tests/unit/test_cli_store_tip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Unit tests for the store-tip behavior in the CLI.

Uses typer's CliRunner with a mocked WrenEngine so no real database is needed.
"""

from __future__ import annotations

from contextlib import contextmanager
from unittest.mock import MagicMock, patch

import pyarrow as pa
import pytest
from typer.testing import CliRunner

from wren.cli import app

# Apply the ``unit`` marker to every test in this module.
pytestmark = pytest.mark.unit

# One CLI runner shared by all tests; each test calls ``runner.invoke``.
runner = CliRunner()

# Minimal pyarrow table for mock query results
_MOCK_TABLE = pa.table({"id": [1, 2, 3]})

# Placeholder arguments appended to every invocation. The tests patch
# ``_build_engine``, so presumably these paths are never opened — confirm.
_CONN_FILE_ARGS = ["--connection-file", "/dev/null"]
_MDL_ARGS = ["--mdl", "/dev/null"]


@contextmanager
def _mock_engine(result=None):
    """Patch ``wren.cli._build_engine`` for the duration of the block.

    The patched builder returns a fake engine that works as a context
    manager (entering yields the fake itself, exiting does not swallow
    exceptions). Its ``query`` method returns *result*, or the module's
    default mock table when *result* is None. The fake engine is yielded so
    tests can inspect calls made on it.
    """
    query_result = result if result is not None else _MOCK_TABLE

    fake = MagicMock()
    fake.__enter__ = MagicMock(return_value=fake)
    fake.__exit__ = MagicMock(return_value=False)
    fake.query.return_value = query_result

    builder_patch = patch("wren.cli._build_engine", return_value=fake)
    with builder_patch:
        yield fake


# ── main() (default callback) ──────────────────────────────────────────────


def test_main_tip_appears_on_stderr_for_analytical_query():
    """Default command: an aggregate query prints the tip on stderr only."""
    argv = [
        "--sql",
        "SELECT status, COUNT(*) FROM orders GROUP BY 1",
        *_MDL_ARGS,
        *_CONN_FILE_ARGS,
    ]
    with _mock_engine():
        outcome = runner.invoke(app, argv)

    tip_marker = "wren memory store"
    assert outcome.exit_code == 0
    assert tip_marker in outcome.stderr
    assert tip_marker not in outcome.stdout


def test_main_tip_suppressed_with_quiet_flag():
    """Default command: ``--quiet`` silences the tip for an aggregate query."""
    argv = [
        "--sql",
        "SELECT status, COUNT(*) FROM orders GROUP BY 1",
        "--quiet",
        *_MDL_ARGS,
        *_CONN_FILE_ARGS,
    ]
    with _mock_engine():
        outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0
    assert "wren memory store" not in outcome.stderr


def test_main_tip_suppressed_for_exploratory_query():
    """Default command: a bare SELECT with LIMIT produces no tip."""
    argv = ["--sql", "SELECT * FROM orders LIMIT 5", *_MDL_ARGS, *_CONN_FILE_ARGS]
    with _mock_engine():
        outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0
    assert "wren memory store" not in outcome.stderr


def test_main_stdout_clean_of_tip():
    """Default command: the tip never leaks into stdout."""
    argv = [
        "--sql",
        "SELECT status, COUNT(*) FROM orders GROUP BY 1",
        *_MDL_ARGS,
        *_CONN_FILE_ARGS,
    ]
    with _mock_engine():
        outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0
    assert "wren memory store" not in outcome.stdout


# ── query subcommand ───────────────────────────────────────────────────────


def test_query_tip_appears_for_analytical_query():
    """``query`` subcommand: a filtered SELECT prints the tip on stderr."""
    argv = [
        "query",
        "--sql",
        "SELECT * FROM orders WHERE total > 100",
        *_MDL_ARGS,
        *_CONN_FILE_ARGS,
    ]
    with _mock_engine():
        outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0
    assert "wren memory store" in outcome.stderr


def test_query_tip_suppressed_with_quiet_flag():
    """``query`` subcommand: ``--quiet`` silences the tip for a filtered SELECT."""
    argv = [
        "query",
        "--sql",
        "SELECT * FROM orders WHERE total > 100",
        "--quiet",
        *_MDL_ARGS,
        *_CONN_FILE_ARGS,
    ]
    with _mock_engine():
        outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0
    assert "wren memory store" not in outcome.stderr


def test_query_tip_suppressed_for_exploratory_query():
    """``query`` subcommand: a bare SELECT without LIMIT produces no tip."""
    argv = ["query", "--sql", "SELECT * FROM orders", *_MDL_ARGS, *_CONN_FILE_ARGS]
    with _mock_engine():
        outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0
    assert "wren memory store" not in outcome.stderr


def test_tip_includes_sql_in_suggested_command():
    """The suggested command echoes the SQL that was just run."""
    statement = "SELECT COUNT(*) FROM orders"
    argv = ["--sql", statement, *_MDL_ARGS, *_CONN_FILE_ARGS]
    with _mock_engine():
        outcome = runner.invoke(app, argv)

    assert outcome.exit_code == 0
    assert statement in outcome.stderr


def test_tip_escapes_single_quotes_in_sql():
    """Single quotes in the SQL are escaped for POSIX single-quote wrapping."""
    argv = [
        "--sql",
        "SELECT * FROM orders WHERE status = 'open'",
        *_MDL_ARGS,
        *_CONN_FILE_ARGS,
    ]
    with _mock_engine():
        outcome = runner.invoke(app, argv)

    stderr_text = outcome.stderr
    assert outcome.exit_code == 0
    assert "wren memory store" in stderr_text

    # Correctly escaped form: close quote, escaped quote, reopen quote.
    escaped_form = "--sql 'SELECT * FROM orders WHERE status = '\\''open'\\'''"
    assert escaped_form in stderr_text

    # The naive, unescaped form would break the shell and must not appear.
    naive_form = "--sql 'SELECT * FROM orders WHERE status = 'open''"
    assert naive_form not in stderr_text
49 changes: 49 additions & 0 deletions wren/tests/unit/test_sql_classify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Unit tests for sql_classify.is_exploratory()."""

import pytest

from wren.sql_classify import is_exploratory

# Apply the ``unit`` marker to every test in this module.
pytestmark = pytest.mark.unit


@pytest.mark.parametrize(
    ("sql", "expected"),
    [
        # --- exploratory: bare SELECT, LIMIT optional ---
        ("SELECT * FROM orders LIMIT 5", True),
        ("SELECT DISTINCT status FROM orders LIMIT 10", True),
        ("SELECT * FROM orders", True),  # no LIMIT is still exploratory
        # Only the subquery has a LIMIT; the outer SELECT has no conditions.
        ("SELECT * FROM (SELECT * FROM orders LIMIT 5) t", True),
        # --- not exploratory ---
        # aggregate with GROUP BY
        ("SELECT status, COUNT(*) FROM orders GROUP BY 1", False),
        # WHERE clause
        ("SELECT * FROM orders WHERE total > 100 LIMIT 10", False),
        # UNION is not a bare SELECT
        ("SELECT a FROM x UNION SELECT b FROM y", False),
        # CTE
        ("WITH cte AS (SELECT 1) SELECT * FROM cte LIMIT 1", False),
        # GROUP BY without any aggregate
        ("SELECT status FROM orders GROUP BY status LIMIT 5", False),
        # HAVING
        (
            "SELECT status, COUNT(*) FROM orders GROUP BY status HAVING COUNT(*) > 1",
            False,
        ),
        # scalar aggregates without GROUP BY
        ("SELECT COUNT(*) FROM orders", False),
        ("SELECT SUM(total) FROM orders", False),
    ],
)
def test_is_exploratory(sql, expected):
    """Each statement is classified as exploratory exactly when expected."""
    verdict = is_exploratory(sql)
    assert verdict is expected


def test_unparseable_sql_returns_false():
    """Input that is not valid SQL is classified as not exploratory."""
    garbage = "NOT VALID SQL $$$$"
    verdict = is_exploratory(garbage)
    assert verdict is False


def test_empty_string_returns_false():
    """An empty SQL string is classified as not exploratory."""
    verdict = is_exploratory("")
    assert verdict is False
Loading