diff --git a/wren/src/wren/memory/__init__.py b/wren/src/wren/memory/__init__.py index 567f0fda8..834bc9bfb 100644 --- a/wren/src/wren/memory/__init__.py +++ b/wren/src/wren/memory/__init__.py @@ -28,9 +28,20 @@ def __init__(self, path: str | Path | None = None): self._store = MemoryStore(path=path) - def index_manifest(self, manifest: dict, *, replace: bool = True) -> int: - """Index MDL schema into LanceDB. Returns record count.""" - return self._store.index_schema(manifest, replace=replace) + def index_manifest( + self, + manifest: dict, + *, + replace: bool = True, + seed_queries: bool = True, + ) -> dict: + """Index MDL schema into LanceDB. + + Returns {"schema_items": int, "seed_queries": int}. + """ + return self._store.index_schema( + manifest, replace=replace, seed_queries=seed_queries + ) @staticmethod def describe_schema(manifest: dict) -> str: diff --git a/wren/src/wren/memory/cli.py b/wren/src/wren/memory/cli.py index c7e802a78..85c834b94 100644 --- a/wren/src/wren/memory/cli.py +++ b/wren/src/wren/memory/cli.py @@ -125,12 +125,20 @@ def _print_results(results: list[dict], output: str) -> None: def index( mdl: MdlOpt = None, path: PathOpt = None, + no_seed: Annotated[ + bool, + typer.Option("--no-seed", help="Skip generating seed NL-SQL examples."), + ] = False, ) -> None: - """Index MDL schema into LanceDB for semantic search.""" + """Index MDL schema into LanceDB (and optionally seed example queries).""" manifest = _load_manifest(mdl) store = _get_store(path) - count = store.index_schema(manifest) - typer.echo(f"Indexed {count} schema items.") + result = store.index_schema(manifest, seed_queries=not no_seed) + typer.echo( + f"Indexed {result['schema_items']} schema items" + + (f", {result['seed_queries']} seed queries" if result["seed_queries"] else "") + + "." + ) @memory_app.command() diff --git a/wren/src/wren/memory/seed_queries.py b/wren/src/wren/memory/seed_queries.py new file mode 100644 index 000000000..dbd8f45bc --- /dev/null +++ b/wren/src/wren/memory/seed_queries.py @@ -0,0 +1,109 @@ +"""Generate canonical NL-SQL seed pairs from an MDL manifest. + +Pure functions — no LanceDB or embedding dependency. +""" + +from __future__ import annotations + +_NUMERIC_TYPES = { + "int", + "integer", + "bigint", + "smallint", + "tinyint", + "float", + "double", + "decimal", + "numeric", + "real", + "number", +} + +SEED_TAG = "source:seed" + + +def generate_seed_queries(manifest: dict) -> list[dict]: + """Return a list of {"nl": ..., "sql": ...} seed pairs.""" + pairs: list[dict] = [] + + for model in manifest.get("models", []): + pairs.extend(_model_seeds(model)) + + for rel in manifest.get("relationships", []): + pair = _relationship_seed(rel) + if pair: + pairs.append(pair) + + return pairs + + +def _model_seeds(model: dict) -> list[dict]: + name = model["name"] + columns = model.get("columns", []) + pairs = [] + + # Template 1: basic listing + pairs.append( + { + "nl": f"List all {name}", + "sql": f"SELECT * FROM {name} LIMIT 100", + } + ) + + # Find first numeric column (non-calculated) and first groupable column + numeric_col = None + group_col = None + for col in columns: + col_type = (col.get("type") or "").split("(")[0].lower().strip() + is_calc = col.get("isCalculated", False) + is_pk = col["name"] == model.get("primaryKey") + + if ( + col_type in _NUMERIC_TYPES + and not is_calc + and not is_pk + and numeric_col is None + ): + numeric_col = col["name"] + elif ( + col_type not in _NUMERIC_TYPES + and not is_pk + and not is_calc + and group_col is None + ): + group_col = col["name"] + + # Template 2a: simple aggregation + if numeric_col: + pairs.append( + { + "nl": f"Total {numeric_col} in {name}", + "sql": f"SELECT SUM({numeric_col}) FROM {name}", + } + ) + + # Template 2b: grouped aggregation + if numeric_col and group_col: + pairs.append( + { + "nl": f"{numeric_col} by {group_col} in {name}", + "sql": f"SELECT {group_col}, SUM({numeric_col}) FROM {name} GROUP BY 1", + } + ) + + return pairs + + +def _relationship_seed(rel: dict) -> dict | None: + models = rel.get("models", []) + condition = rel.get("condition", "").strip() + + if len(models) < 2 or not condition: + return None + + left, right = models[0], models[1] + + return { + "nl": f"{left} with {right} details", + "sql": f"SELECT * FROM {left} JOIN {right} ON {condition} LIMIT 100", + } diff --git a/wren/src/wren/memory/store.py b/wren/src/wren/memory/store.py index a2fb6a192..17cc05361 100644 --- a/wren/src/wren/memory/store.py +++ b/wren/src/wren/memory/store.py @@ -100,10 +100,20 @@ def _query_table_schema(self) -> pa.Schema: # ── Schema indexing ─────────────────────────────────────────────────── - def index_schema(self, manifest: dict, *, replace: bool = True) -> int: + def index_schema( + self, + manifest: dict, + *, + replace: bool = True, + seed_queries: bool = True, + ) -> dict: """Extract schema items from *manifest*, embed, and store. - Returns the number of records indexed. + If *seed_queries* is True, also generates canonical NL-SQL pairs + and inserts them into query_history (tagged 'source:seed'). + Old seed entries are replaced; user-confirmed entries are preserved. + + Returns {"schema_items": int, "seed_queries": int}. """ items = extract_schema_items(manifest) table_exists = _SCHEMA_TABLE in _table_names(self._db) @@ -111,34 +121,65 @@ def index_schema(self, manifest: dict, *, replace: bool = True) -> int: if not items: if replace and table_exists: self._db.drop_table(_SCHEMA_TABLE) - return 0 - - texts = [item["text"] for item in items] - vectors = self._embed_fn.compute_source_embeddings(texts) + schema_count = 0 + else: + texts = [item["text"] for item in items] + vectors = self._embed_fn.compute_source_embeddings(texts) - for item, vec in zip(items, vectors): - item["vector"] = vec + for item, vec in zip(items, vectors): + item["vector"] = vec - if replace: - if table_exists: - self._db.drop_table(_SCHEMA_TABLE) - self._db.create_table( - _SCHEMA_TABLE, - items, - schema=self._schema_table_schema(), - ) - else: - if table_exists: - tbl = self._db.open_table(_SCHEMA_TABLE) - tbl.add(items) - else: + if replace: + if table_exists: + self._db.drop_table(_SCHEMA_TABLE) self._db.create_table( _SCHEMA_TABLE, items, schema=self._schema_table_schema(), ) + else: + if table_exists: + tbl = self._db.open_table(_SCHEMA_TABLE) + tbl.add(items) + else: + self._db.create_table( + _SCHEMA_TABLE, + items, + schema=self._schema_table_schema(), + ) + schema_count = len(items) + + seed_count = 0 + if seed_queries: + seed_count = self._upsert_seed_queries(manifest) + + return {"schema_items": schema_count, "seed_queries": seed_count} + + def _upsert_seed_queries(self, manifest: dict) -> int: + """Replace seed query entries, preserving user-confirmed ones.""" + from wren.memory.seed_queries import ( # noqa: PLC0415 + SEED_TAG, + generate_seed_queries, + ) + + # Remove old seeds (tagged 'source:seed') but keep user entries + if _QUERY_TABLE in _table_names(self._db): + table = self._db.open_table(_QUERY_TABLE) + table.delete(f"tags = '{SEED_TAG}'") + + pairs = generate_seed_queries(manifest) + if not pairs: + return 0 + + # Insert new seeds via the existing store_query() method + for pair in pairs: + self.store_query( + nl_query=pair["nl"], + sql_query=pair["sql"], + tags=SEED_TAG, + ) - return len(items) + return len(pairs) def schema_is_current(self, manifest: dict) -> bool: """Check whether the indexed schema matches *manifest*. diff --git a/wren/tests/unit/test_memory.py b/wren/tests/unit/test_memory.py index c4b288f76..d6b76d219 100644 --- a/wren/tests/unit/test_memory.py +++ b/wren/tests/unit/test_memory.py @@ -236,13 +236,13 @@ def memory_store(tmp_path): @pytest.mark.unit class TestMemoryStore: def test_index_and_context(self, memory_store): - count = memory_store.index_schema(_MANIFEST) - assert count == 10 + result = memory_store.index_schema(_MANIFEST) + assert result["schema_items"] == 10 # Small schema → full strategy - result = memory_store.get_context(_MANIFEST, "customer order price") - assert result["strategy"] == "full" - assert "### Model: orders" in result["schema"] + ctx = memory_store.get_context(_MANIFEST, "customer order price") + assert ctx["strategy"] == "full" + assert "### Model: orders" in ctx["schema"] def test_context_search_strategy(self, memory_store): memory_store.index_schema(_MANIFEST) @@ -316,8 +316,8 @@ def test_describe_schema_static(self, memory_store): def test_index_replace(self, memory_store): memory_store.index_schema(_MANIFEST) - count = memory_store.index_schema(_MANIFEST, replace=True) - assert count == 10 + result = memory_store.index_schema(_MANIFEST, replace=True) + assert result["schema_items"] == 10 info = memory_store.status() assert info["tables"]["schema_items"] == 10 @@ -339,8 +339,8 @@ def wren_memory(tmp_path): @pytest.mark.unit class TestWrenMemory: def test_full_lifecycle(self, wren_memory): - count = wren_memory.index_manifest(_MANIFEST) - assert count == 10 + result = wren_memory.index_manifest(_MANIFEST) + assert result["schema_items"] == 10 ctx = wren_memory.get_context(_MANIFEST, "customer") assert ctx["strategy"] == "full" @@ -357,7 +357,83 @@ def test_full_lifecycle(self, wren_memory): status = wren_memory.status() assert status["tables"]["schema_items"] == 10 - assert status["tables"]["query_history"] == 1 + # query_history has 1 user query + seed queries + assert status["tables"]["query_history"] >= 1 wren_memory.reset() assert wren_memory.status()["tables"] == {} + + +# ── MemoryStore seed lifecycle tests ───────────────────────────────────── + + +@pytest.mark.unit +class TestMemoryStoreSeedLifecycle: + def test_index_schema_seeds_query_history(self, memory_store): + from wren.memory.seed_queries import SEED_TAG # noqa: PLC0415 + + result = memory_store.index_schema(_MANIFEST, seed_queries=True) + assert result["seed_queries"] > 0 + + import pandas as pd # noqa: PLC0415 + + table = memory_store._db.open_table("query_history") + df = table.to_pandas() + seeds = df[df["tags"] == SEED_TAG] + assert len(seeds) == result["seed_queries"] + + def test_index_schema_no_seed_flag(self, memory_store): + result = memory_store.index_schema(_MANIFEST, seed_queries=False) + assert result["seed_queries"] == 0 + from wren.memory.store import _table_names # noqa: PLC0415 + + assert "query_history" not in _table_names(memory_store._db) + + def test_reindex_replaces_seeds_preserves_user_queries(self, memory_store): + from wren.memory.seed_queries import SEED_TAG # noqa: PLC0415 + + # Index once to create seeds + first = memory_store.index_schema(_MANIFEST, seed_queries=True) + seed_count = first["seed_queries"] + assert seed_count > 0 + + # Store a user-confirmed query (no seed tag) + memory_store.store_query( + nl_query="show me the most expensive orders", + sql_query="SELECT * FROM orders ORDER BY o_totalprice DESC LIMIT 10", + ) + + import pandas as pd # noqa: PLC0415 + + table = memory_store._db.open_table("query_history") + total_before = table.count_rows() + assert total_before == seed_count + 1 + + # Re-index — seeds should be replaced, user entry preserved + second = memory_store.index_schema(_MANIFEST, seed_queries=True) + assert second["seed_queries"] == seed_count + + table = memory_store._db.open_table("query_history") + df = table.to_pandas() + seeds = df[df["tags"] == SEED_TAG] + user_rows = df[df["tags"] != SEED_TAG] + assert len(seeds) == seed_count + assert len(user_rows) == 1 + assert "expensive orders" in user_rows.iloc[0]["nl_query"] + + def test_recall_returns_seed_entries(self, memory_store): + from wren.memory.seed_queries import SEED_TAG # noqa: PLC0415 + + memory_store.index_schema(_MANIFEST, seed_queries=True) + results = memory_store.recall_queries("list all orders", limit=5) + assert len(results) > 0 + # At least one result should be a seed entry + tags = [r.get("tags", "") for r in results] + assert any(t == SEED_TAG for t in tags) + + def test_index_schema_returns_dict(self, memory_store): + result = memory_store.index_schema(_MANIFEST) + assert isinstance(result, dict) + assert "schema_items" in result + assert "seed_queries" in result + assert result["schema_items"] == 10 diff --git a/wren/tests/unit/test_seed_queries.py b/wren/tests/unit/test_seed_queries.py new file mode 100644 index 000000000..42c16318a --- /dev/null +++ b/wren/tests/unit/test_seed_queries.py @@ -0,0 +1,292 @@ +"""Unit tests for wren.memory.seed_queries.""" + +from __future__ import annotations + +import pytest + +from wren.memory.seed_queries import SEED_TAG, generate_seed_queries + + +# ── Fixtures ────────────────────────────────────────────────────────────── + + +def _model(name: str, pk: str, columns: list[dict]) -> dict: + return {"name": name, "primaryKey": pk, "columns": columns} + + +def _col(name: str, col_type: str, *, is_calc: bool = False) -> dict: + return {"name": name, "type": col_type, "isCalculated": is_calc} + + +# ── generate_seed_queries tests ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestGenerateSeedQueries: + def test_empty_manifest(self): + assert generate_seed_queries({}) == [] + + def test_empty_models_list(self): + assert generate_seed_queries({"models": []}) == [] + + def test_single_model_all_string_columns(self): + manifest = { + "models": [ + _model( + "orders", + "id", + [ + _col("id", "varchar"), + _col("status", "varchar"), + _col("note", "varchar"), + ], + ) + ] + } + pairs = generate_seed_queries(manifest) + assert len(pairs) == 1 + assert pairs[0]["nl"] == "List all orders" + assert "SELECT * FROM orders" in pairs[0]["sql"] + + def test_single_model_only_pk_column(self): + # PK is varchar — no numeric col, no group col → listing only + manifest = { + "models": [_model("products", "product_id", [_col("product_id", "varchar")])] + } + pairs = generate_seed_queries(manifest) + assert len(pairs) == 1 + assert pairs[0]["nl"] == "List all products" + + def test_numeric_pk_not_used_for_aggregation(self): + # PK is numeric — should not be picked as aggregate target + manifest = { + "models": [ + _model( + "orders", + "order_id", + [ + _col("order_id", "int"), + _col("status", "varchar"), + ], + ) + ] + } + pairs = generate_seed_queries(manifest) + # Only listing — no aggregation since numeric col is a PK + assert len(pairs) == 1 + assert pairs[0]["nl"] == "List all orders" + + def test_single_model_one_numeric_no_group(self): + # Only columns: pk (varchar) + amount (double) — no eligible group col + manifest = { + "models": [ + _model( + "payments", + "payment_id", + [ + _col("payment_id", "varchar"), + _col("amount", "double"), + ], + ) + ] + } + pairs = generate_seed_queries(manifest) + nls = [p["nl"] for p in pairs] + assert "List all payments" in nls + assert "Total amount in payments" in nls + # No group col → no grouped aggregation + assert not any("by" in nl for nl in nls) + assert len(pairs) == 2 + + def test_single_model_numeric_and_group_column(self): + # 3 columns: pk (varchar) + status (varchar) + amount (double) + manifest = { + "models": [ + _model( + "orders", + "order_id", + [ + _col("order_id", "varchar"), + _col("status", "varchar"), + _col("amount", "double"), + ], + ) + ] + } + pairs = generate_seed_queries(manifest) + assert len(pairs) == 3 + nls = [p["nl"] for p in pairs] + assert "List all orders" in nls + assert "Total amount in orders" in nls + assert "amount by status in orders" in nls + + def test_grouped_aggregation_sql(self): + manifest = { + "models": [ + _model( + "sales", + "id", + [ + _col("id", "varchar"), + _col("region", "varchar"), + _col("revenue", "decimal"), + ], + ) + ] + } + pairs = generate_seed_queries(manifest) + grouped = next(p for p in pairs if "by" in p["nl"]) + assert grouped["sql"] == "SELECT region, SUM(revenue) FROM sales GROUP BY 1" + + def test_calculated_columns_skipped_for_aggregation(self): + # The only numeric col is calculated — should be skipped for agg template + manifest = { + "models": [ + _model( + "orders", + "id", + [ + _col("id", "varchar"), + _col("status", "varchar"), + _col("total", "double", is_calc=True), + ], + ) + ] + } + pairs = generate_seed_queries(manifest) + assert len(pairs) == 1 + assert pairs[0]["nl"] == "List all orders" + + def test_numeric_type_case_insensitive(self): + # Type stored in mixed case (e.g. "BIGINT", "Float") + manifest = { + "models": [ + _model( + "metrics", + "id", + [ + _col("id", "varchar"), + _col("category", "varchar"), + _col("count", "BIGINT"), + ], + ) + ] + } + pairs = generate_seed_queries(manifest) + assert any("Total count" in p["nl"] for p in pairs) + + def test_numeric_type_with_precision(self): + # Type like "DECIMAL(10,2)" — precision should be stripped + manifest = { + "models": [ + _model( + "invoices", + "id", + [ + _col("id", "varchar"), + _col("customer", "varchar"), + _col("amount", "DECIMAL(10,2)"), + ], + ) + ] + } + pairs = generate_seed_queries(manifest) + assert any("Total amount" in p["nl"] for p in pairs) + + def test_two_models_with_relationship(self): + manifest = { + "models": [ + _model( + "orders", + "order_id", + [ + _col("order_id", "varchar"), + _col("status", "varchar"), + _col("amount", "double"), + ], + ), + _model( + "customers", + "customer_id", + [ + _col("customer_id", "varchar"), + _col("name", "varchar"), + _col("score", "int"), + ], + ), + ], + "relationships": [ + { + "name": "orders_customers", + "models": ["orders", "customers"], + "condition": "orders.customer_id = customers.customer_id", + } + ], + } + pairs = generate_seed_queries(manifest) + # orders: 3 (listing + agg + grouped), customers: 3 (listing + agg + grouped), rel: 1 + assert len(pairs) == 7 + nls = [p["nl"] for p in pairs] + assert "orders with customers details" in nls + + def test_relationship_join_sql(self): + manifest = { + "models": [ + _model("a", "aid", [_col("aid", "varchar")]), + _model("b", "bid", [_col("bid", "varchar")]), + ], + "relationships": [ + { + "name": "a_b", + "models": ["a", "b"], + "condition": "a.bid = b.bid", + } + ], + } + pairs = generate_seed_queries(manifest) + join_pair = next(p for p in pairs if "with" in p["nl"]) + assert join_pair["sql"] == "SELECT * FROM a JOIN b ON a.bid = b.bid LIMIT 100" + + def test_relationship_missing_condition_skipped(self): + manifest = { + "models": [ + _model("a", "aid", [_col("aid", "varchar")]), + _model("b", "bid", [_col("bid", "varchar")]), + ], + "relationships": [ + {"name": "a_b", "models": ["a", "b"], "condition": ""} + ], + } + pairs = generate_seed_queries(manifest) + assert not any("with" in p["nl"] for p in pairs) + + def test_relationship_whitespace_condition_skipped(self): + manifest = { + "models": [ + _model("a", "aid", [_col("aid", "varchar")]), + _model("b", "bid", [_col("bid", "varchar")]), + ], + "relationships": [ + {"name": "a_b", "models": ["a", "b"], "condition": " "} + ], + } + pairs = generate_seed_queries(manifest) + assert not any("with" in p["nl"] for p in pairs) + + def test_relationship_fewer_than_two_models_skipped(self): + manifest = { + "models": [_model("a", "aid", [_col("aid", "varchar")])], + "relationships": [ + {"name": "a_self", "models": ["a"], "condition": "a.id = a.id"} + ], + } + pairs = generate_seed_queries(manifest) + assert not any("with" in p["nl"] for p in pairs) + + def test_seed_tag_constant(self): + assert SEED_TAG == "source:seed" + + def test_listing_sql_has_limit(self): + manifest = {"models": [_model("t", "id", [_col("id", "varchar")])]} + pairs = generate_seed_queries(manifest) + assert "LIMIT 100" in pairs[0]["sql"]