Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions wren/src/wren/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,20 @@ def __init__(self, path: str | Path | None = None):

self._store = MemoryStore(path=path)

def index_manifest(self, manifest: dict, *, replace: bool = True) -> int:
"""Index MDL schema into LanceDB. Returns record count."""
return self._store.index_schema(manifest, replace=replace)
def index_manifest(
self,
manifest: dict,
*,
replace: bool = True,
seed_queries: bool = True,
) -> dict:
"""Index MDL schema into LanceDB.

Returns {"schema_items": int, "seed_queries": int}.
"""
return self._store.index_schema(
manifest, replace=replace, seed_queries=seed_queries
)

@staticmethod
def describe_schema(manifest: dict) -> str:
Expand Down
14 changes: 11 additions & 3 deletions wren/src/wren/memory/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,20 @@ def _print_results(results: list[dict], output: str) -> None:
def index(
mdl: MdlOpt = None,
path: PathOpt = None,
no_seed: Annotated[
bool,
typer.Option("--no-seed", help="Skip generating seed NL-SQL examples."),
] = False,
) -> None:
"""Index MDL schema into LanceDB for semantic search."""
"""Index MDL schema into LanceDB (and optionally seed example queries)."""
manifest = _load_manifest(mdl)
store = _get_store(path)
count = store.index_schema(manifest)
typer.echo(f"Indexed {count} schema items.")
result = store.index_schema(manifest, seed_queries=not no_seed)
typer.echo(
f"Indexed {result['schema_items']} schema items"
+ (f", {result['seed_queries']} seed queries" if result["seed_queries"] else "")
+ "."
)


@memory_app.command()
Expand Down
109 changes: 109 additions & 0 deletions wren/src/wren/memory/seed_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Generate canonical NL-SQL seed pairs from an MDL manifest.

Pure functions — no LanceDB or embedding dependency.
"""

from __future__ import annotations

_NUMERIC_TYPES = {
"int",
"integer",
"bigint",
"smallint",
"tinyint",
"float",
"double",
"decimal",
"numeric",
"real",
"number",
}

SEED_TAG = "source:seed"


def generate_seed_queries(manifest: dict) -> list[dict]:
"""Return a list of {"nl": ..., "sql": ...} seed pairs."""
pairs: list[dict] = []

for model in manifest.get("models", []):
pairs.extend(_model_seeds(model))

for rel in manifest.get("relationships", []):
pair = _relationship_seed(rel)
if pair:
pairs.append(pair)

return pairs


def _model_seeds(model: dict) -> list[dict]:
name = model["name"]
columns = model.get("columns", [])
pairs = []

# Template 1: basic listing
pairs.append(
{
"nl": f"List all {name}",
"sql": f"SELECT * FROM {name} LIMIT 100",
}
)

# Find first numeric column (non-calculated) and first groupable column
numeric_col = None
group_col = None
for col in columns:
col_type = (col.get("type") or "").split("(")[0].lower().strip()
is_calc = col.get("isCalculated", False)
is_pk = col["name"] == model.get("primaryKey")

if (
col_type in _NUMERIC_TYPES
and not is_calc
and not is_pk
and numeric_col is None
):
numeric_col = col["name"]
elif (
col_type not in _NUMERIC_TYPES
and not is_pk
and not is_calc
and group_col is None
):
group_col = col["name"]

# Template 2a: simple aggregation
if numeric_col:
pairs.append(
{
"nl": f"Total {numeric_col} in {name}",
"sql": f"SELECT SUM({numeric_col}) FROM {name}",
}
)

# Template 2b: grouped aggregation
if numeric_col and group_col:
pairs.append(
{
"nl": f"{numeric_col} by {group_col} in {name}",
"sql": f"SELECT {group_col}, SUM({numeric_col}) FROM {name} GROUP BY 1",
}
)

return pairs


def _relationship_seed(rel: dict) -> dict | None:
models = rel.get("models", [])
condition = rel.get("condition", "").strip()

if len(models) < 2 or not condition:
return None

left, right = models[0], models[1]

return {
"nl": f"{left} with {right} details",
"sql": f"SELECT * FROM {left} JOIN {right} ON {condition} LIMIT 100",
}
85 changes: 63 additions & 22 deletions wren/src/wren/memory/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,45 +100,86 @@ def _query_table_schema(self) -> pa.Schema:

# ── Schema indexing ───────────────────────────────────────────────────

def index_schema(self, manifest: dict, *, replace: bool = True) -> int:
def index_schema(
self,
manifest: dict,
*,
replace: bool = True,
seed_queries: bool = True,
) -> dict:
"""Extract schema items from *manifest*, embed, and store.

Returns the number of records indexed.
If *seed_queries* is True, also generates canonical NL-SQL pairs
and inserts them into query_history (tagged 'source:seed').
Old seed entries are replaced; user-confirmed entries are preserved.

Returns {"schema_items": int, "seed_queries": int}.
"""
items = extract_schema_items(manifest)
table_exists = _SCHEMA_TABLE in _table_names(self._db)

if not items:
if replace and table_exists:
self._db.drop_table(_SCHEMA_TABLE)
return 0

texts = [item["text"] for item in items]
vectors = self._embed_fn.compute_source_embeddings(texts)
schema_count = 0
else:
texts = [item["text"] for item in items]
vectors = self._embed_fn.compute_source_embeddings(texts)

for item, vec in zip(items, vectors):
item["vector"] = vec
for item, vec in zip(items, vectors):
item["vector"] = vec

if replace:
if table_exists:
self._db.drop_table(_SCHEMA_TABLE)
self._db.create_table(
_SCHEMA_TABLE,
items,
schema=self._schema_table_schema(),
)
else:
if table_exists:
tbl = self._db.open_table(_SCHEMA_TABLE)
tbl.add(items)
else:
if replace:
if table_exists:
self._db.drop_table(_SCHEMA_TABLE)
self._db.create_table(
_SCHEMA_TABLE,
items,
schema=self._schema_table_schema(),
)
else:
if table_exists:
tbl = self._db.open_table(_SCHEMA_TABLE)
tbl.add(items)
else:
self._db.create_table(
_SCHEMA_TABLE,
items,
schema=self._schema_table_schema(),
)
schema_count = len(items)

seed_count = 0
if seed_queries:
seed_count = self._upsert_seed_queries(manifest)

return {"schema_items": schema_count, "seed_queries": seed_count}

def _upsert_seed_queries(self, manifest: dict) -> int:
"""Replace seed query entries, preserving user-confirmed ones."""
from wren.memory.seed_queries import ( # noqa: PLC0415
SEED_TAG,
generate_seed_queries,
)

# Remove old seeds (tagged 'source:seed') but keep user entries
if _QUERY_TABLE in _table_names(self._db):
table = self._db.open_table(_QUERY_TABLE)
table.delete(f"tags = '{SEED_TAG}'")

pairs = generate_seed_queries(manifest)
if not pairs:
return 0

# Insert new seeds via the existing store_query() method
for pair in pairs:
self.store_query(
nl_query=pair["nl"],
sql_query=pair["sql"],
tags=SEED_TAG,
)

return len(items)
return len(pairs)

def schema_is_current(self, manifest: dict) -> bool:
"""Check whether the indexed schema matches *manifest*.
Expand Down
96 changes: 86 additions & 10 deletions wren/tests/unit/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,13 @@ def memory_store(tmp_path):
@pytest.mark.unit
class TestMemoryStore:
def test_index_and_context(self, memory_store):
count = memory_store.index_schema(_MANIFEST)
assert count == 10
result = memory_store.index_schema(_MANIFEST)
assert result["schema_items"] == 10

# Small schema → full strategy
result = memory_store.get_context(_MANIFEST, "customer order price")
assert result["strategy"] == "full"
assert "### Model: orders" in result["schema"]
ctx = memory_store.get_context(_MANIFEST, "customer order price")
assert ctx["strategy"] == "full"
assert "### Model: orders" in ctx["schema"]

def test_context_search_strategy(self, memory_store):
memory_store.index_schema(_MANIFEST)
Expand Down Expand Up @@ -316,8 +316,8 @@ def test_describe_schema_static(self, memory_store):

def test_index_replace(self, memory_store):
memory_store.index_schema(_MANIFEST)
count = memory_store.index_schema(_MANIFEST, replace=True)
assert count == 10
result = memory_store.index_schema(_MANIFEST, replace=True)
assert result["schema_items"] == 10
info = memory_store.status()
assert info["tables"]["schema_items"] == 10

Expand All @@ -339,8 +339,8 @@ def wren_memory(tmp_path):
@pytest.mark.unit
class TestWrenMemory:
def test_full_lifecycle(self, wren_memory):
count = wren_memory.index_manifest(_MANIFEST)
assert count == 10
result = wren_memory.index_manifest(_MANIFEST)
assert result["schema_items"] == 10

ctx = wren_memory.get_context(_MANIFEST, "customer")
assert ctx["strategy"] == "full"
Expand All @@ -357,7 +357,83 @@ def test_full_lifecycle(self, wren_memory):

status = wren_memory.status()
assert status["tables"]["schema_items"] == 10
assert status["tables"]["query_history"] == 1
# query_history has 1 user query + seed queries
assert status["tables"]["query_history"] >= 1

wren_memory.reset()
assert wren_memory.status()["tables"] == {}


# ── MemoryStore seed lifecycle tests ─────────────────────────────────────


@pytest.mark.unit
class TestMemoryStoreSeedLifecycle:
def test_index_schema_seeds_query_history(self, memory_store):
from wren.memory.seed_queries import SEED_TAG # noqa: PLC0415

result = memory_store.index_schema(_MANIFEST, seed_queries=True)
assert result["seed_queries"] > 0

import pandas as pd # noqa: PLC0415

table = memory_store._db.open_table("query_history")
df = table.to_pandas()
seeds = df[df["tags"] == SEED_TAG]
assert len(seeds) == result["seed_queries"]

def test_index_schema_no_seed_flag(self, memory_store):
result = memory_store.index_schema(_MANIFEST, seed_queries=False)
assert result["seed_queries"] == 0
from wren.memory.store import _table_names # noqa: PLC0415

assert "query_history" not in _table_names(memory_store._db)

def test_reindex_replaces_seeds_preserves_user_queries(self, memory_store):
from wren.memory.seed_queries import SEED_TAG # noqa: PLC0415

# Index once to create seeds
first = memory_store.index_schema(_MANIFEST, seed_queries=True)
seed_count = first["seed_queries"]
assert seed_count > 0

# Store a user-confirmed query (no seed tag)
memory_store.store_query(
nl_query="show me the most expensive orders",
sql_query="SELECT * FROM orders ORDER BY o_totalprice DESC LIMIT 10",
)

import pandas as pd # noqa: PLC0415

table = memory_store._db.open_table("query_history")
total_before = table.count_rows()
assert total_before == seed_count + 1

# Re-index — seeds should be replaced, user entry preserved
second = memory_store.index_schema(_MANIFEST, seed_queries=True)
assert second["seed_queries"] == seed_count

table = memory_store._db.open_table("query_history")
df = table.to_pandas()
seeds = df[df["tags"] == SEED_TAG]
user_rows = df[df["tags"] != SEED_TAG]
assert len(seeds) == seed_count
assert len(user_rows) == 1
assert "expensive orders" in user_rows.iloc[0]["nl_query"]

def test_recall_returns_seed_entries(self, memory_store):
from wren.memory.seed_queries import SEED_TAG # noqa: PLC0415

memory_store.index_schema(_MANIFEST, seed_queries=True)
results = memory_store.recall_queries("list all orders", limit=5)
assert len(results) > 0
# At least one result should be a seed entry
tags = [r.get("tags", "") for r in results]
assert any(t == SEED_TAG for t in tags)

def test_index_schema_returns_dict(self, memory_store):
result = memory_store.index_schema(_MANIFEST)
assert isinstance(result, dict)
assert "schema_items" in result
assert "seed_queries" in result
assert result["schema_items"] == 10
Loading
Loading