diff --git a/api/managers.py b/api/managers.py index 4aa26597..7ab56c01 100644 --- a/api/managers.py +++ b/api/managers.py @@ -223,6 +223,8 @@ def _ensure_universe_schema(conn: Any) -> None: conn.execute("ALTER TABLE managers ADD COLUMN cik TEXT") if "jurisdiction" not in columns: conn.execute("ALTER TABLE managers ADD COLUMN jurisdiction TEXT") + if "jurisdictions" not in columns: + conn.execute("ALTER TABLE managers ADD COLUMN jurisdictions TEXT NOT NULL DEFAULT '[]'") if "created_at" not in columns: conn.execute("ALTER TABLE managers ADD COLUMN created_at TIMESTAMP") conn.execute( @@ -239,6 +241,7 @@ def _ensure_universe_schema(conn: Any) -> None: conn.execute("ALTER TABLE managers ADD COLUMN IF NOT EXISTS cik text") conn.execute("ALTER TABLE managers ADD COLUMN IF NOT EXISTS jurisdiction text") + conn.execute("ALTER TABLE managers ADD COLUMN IF NOT EXISTS jurisdictions text[] DEFAULT '{}'") conn.execute( "ALTER TABLE managers ADD COLUMN IF NOT EXISTS created_at timestamptz DEFAULT now()" ) @@ -258,31 +261,34 @@ def _manager_exists_for_cik(conn: Any, cik: str) -> bool: def _upsert_universe_record(conn: Any, name: str, cik: str, jurisdiction: str) -> None: + jurisdictions = [jurisdiction] if isinstance(conn, sqlite3.Connection): conn.execute( """ - INSERT INTO managers(name, cik, jurisdiction, updated_at) - VALUES (?, ?, ?, CURRENT_TIMESTAMP) + INSERT INTO managers(name, cik, jurisdiction, jurisdictions, updated_at) + VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP) ON CONFLICT(cik) DO UPDATE SET name = excluded.name, jurisdiction = excluded.jurisdiction, + jurisdictions = excluded.jurisdictions, updated_at = CURRENT_TIMESTAMP """, - (name, cik, jurisdiction), + (name, cik, jurisdiction, json.dumps(jurisdictions)), ) return conn.execute( """ - INSERT INTO managers(name, cik, jurisdiction, updated_at) - VALUES (%s, %s, %s, now()) + INSERT INTO managers(name, cik, jurisdiction, jurisdictions, updated_at) + VALUES (%s, %s, %s, %s, now()) ON CONFLICT(cik) DO UPDATE SET name = EXCLUDED.name, jurisdiction = EXCLUDED.jurisdiction, + jurisdictions = EXCLUDED.jurisdictions, updated_at = now() """, - (name, cik, jurisdiction), + (name, cik, jurisdiction, jurisdictions), ) diff --git a/scripts/seed_universe.py b/scripts/seed_universe.py index a85dc83d..938051ff 100644 --- a/scripts/seed_universe.py +++ b/scripts/seed_universe.py @@ -51,6 +51,7 @@ def _ensure_universe_schema(conn: Any) -> None: department TEXT, cik TEXT, jurisdiction TEXT, + jurisdictions TEXT NOT NULL DEFAULT '[]', created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) @@ -60,6 +61,8 @@ def _ensure_universe_schema(conn: Any) -> None: conn.execute("ALTER TABLE managers ADD COLUMN cik TEXT") if "jurisdiction" not in columns: conn.execute("ALTER TABLE managers ADD COLUMN jurisdiction TEXT") + if "jurisdictions" not in columns: + conn.execute("ALTER TABLE managers ADD COLUMN jurisdictions TEXT NOT NULL DEFAULT '[]'") if "created_at" not in columns: conn.execute("ALTER TABLE managers ADD COLUMN created_at TIMESTAMP") conn.execute( @@ -82,12 +85,14 @@ def _ensure_universe_schema(conn: Any) -> None: department text, cik text, jurisdiction text, + jurisdictions text[] DEFAULT '{}', created_at timestamptz DEFAULT now(), updated_at timestamptz DEFAULT now() ) """) conn.execute("ALTER TABLE managers ADD COLUMN IF NOT EXISTS cik text") conn.execute("ALTER TABLE managers ADD COLUMN IF NOT EXISTS jurisdiction text") + conn.execute("ALTER TABLE managers ADD COLUMN IF NOT EXISTS jurisdictions text[] DEFAULT '{}'") conn.execute( "ALTER TABLE managers ADD COLUMN IF NOT EXISTS created_at timestamptz DEFAULT now()" ) @@ -97,6 +102,21 @@ def _ensure_universe_schema(conn: Any) -> None: conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_managers_cik_unique ON managers(cik)") +def _manager_columns(conn: Any) -> set[str]: + if isinstance(conn, sqlite3.Connection): + return { + str(row[1]).lower() for row in conn.execute("PRAGMA table_info(managers)").fetchall() + } + + rows = conn.execute(""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'managers' + AND table_schema = 'public' + """).fetchall() + return {str(row[0]).lower() for row in rows if row and row[0] is not None} + + def _existing_ciks(conn: Any) -> set[str]: rows = conn.execute( "SELECT cik FROM managers WHERE cik IS NOT NULL AND TRIM(cik) != ''" @@ -104,6 +124,77 @@ def _existing_ciks(conn: Any) -> set[str]: return {str(row[0]).strip() for row in rows if row and row[0] is not None} +def _upsert_universe_record( + conn: Any, + *, + name: str, + cik: str, + jurisdiction: str, + include_role: bool, +) -> None: + jurisdictions = [jurisdiction] + if isinstance(conn, sqlite3.Connection): + if include_role: + conn.execute( + """ + INSERT INTO managers(name, role, cik, jurisdiction, jurisdictions, updated_at) + VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP) + ON CONFLICT(cik) + DO UPDATE SET + name = excluded.name, + jurisdiction = excluded.jurisdiction, + jurisdictions = excluded.jurisdictions, + updated_at = CURRENT_TIMESTAMP + """, + (name, DEFAULT_ROLE, cik, jurisdiction, json.dumps(jurisdictions)), + ) + return + conn.execute( + """ + INSERT INTO managers(name, cik, jurisdiction, jurisdictions, updated_at) + VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP) + ON CONFLICT(cik) + DO UPDATE SET + name = excluded.name, + jurisdiction = excluded.jurisdiction, + jurisdictions = excluded.jurisdictions, + updated_at = CURRENT_TIMESTAMP + """, + (name, cik, jurisdiction, json.dumps(jurisdictions)), + ) + return + + if include_role: + conn.execute( + """ + INSERT INTO managers(name, role, cik, jurisdiction, jurisdictions, updated_at) + VALUES (%s, %s, %s, %s, %s, now()) + ON CONFLICT(cik) + DO UPDATE SET + name = EXCLUDED.name, + jurisdiction = EXCLUDED.jurisdiction, + jurisdictions = EXCLUDED.jurisdictions, + updated_at = now() + """, + (name, DEFAULT_ROLE, cik, jurisdiction, jurisdictions), + ) + return + + conn.execute( + """ + INSERT INTO managers(name, cik, jurisdiction, jurisdictions, updated_at) + VALUES (%s, %s, %s, %s, now()) + ON CONFLICT(cik) + DO UPDATE SET + name = EXCLUDED.name, + jurisdiction = EXCLUDED.jurisdiction, + jurisdictions = EXCLUDED.jurisdictions, + updated_at = now() + """, + (name, cik, jurisdiction, jurisdictions), + ) + + def seed_universe(file_path: Path, *, dry_run: bool = False) -> tuple[int, int, int]: records = _load_records(file_path) conn = connect_db() @@ -113,6 +204,7 @@ def seed_universe(file_path: Path, *, dry_run: bool = False) -> tuple[int, int, try: _ensure_universe_schema(conn) known_ciks = _existing_ciks(conn) + include_role = "role" in _manager_columns(conn) for idx, record in enumerate(records): name = str(record.get("name", "")).strip() @@ -133,32 +225,13 @@ def seed_universe(file_path: Path, *, dry_run: bool = False) -> tuple[int, int, if dry_run: continue - if isinstance(conn, sqlite3.Connection): - conn.execute( - """ - INSERT INTO managers(name, role, cik, jurisdiction, updated_at) - VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP) - ON CONFLICT(cik) - DO UPDATE SET - name = excluded.name, - jurisdiction = excluded.jurisdiction, - updated_at = CURRENT_TIMESTAMP - """, - (name, DEFAULT_ROLE, cik, jurisdiction), - ) - else: - conn.execute( - """ - INSERT INTO managers(name, role, cik, jurisdiction, updated_at) - VALUES (%s, %s, %s, %s, now()) - ON CONFLICT(cik) - DO UPDATE SET - name = EXCLUDED.name, - jurisdiction = EXCLUDED.jurisdiction, - updated_at = now() - """, - (name, DEFAULT_ROLE, cik, jurisdiction), - ) + _upsert_universe_record( + conn, + name=name, + cik=cik, + jurisdiction=jurisdiction, + include_role=include_role, + ) if dry_run: print("Dry run complete. No rows written.") diff --git a/tests/test_manager_universe_api.py b/tests/test_manager_universe_api.py index 07838e78..226d684e 100644 --- a/tests/test_manager_universe_api.py +++ b/tests/test_manager_universe_api.py @@ -34,6 +34,18 @@ async def _get_manager_stats(): await app.router.shutdown() +async def _get_managers(params: dict[str, str] | None = None): + await app.router.startup() + try: + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient( + transport=transport, base_url="http://test", timeout=5.0 + ) as client: + return await client.get("/managers", params=params) + finally: + await app.router.shutdown() + + def test_universe_import_creates_updates_and_skips_records(tmp_path, monkeypatch): db_path = tmp_path / "dev.db" monkeypatch.setenv("DB_PATH", str(db_path)) @@ -60,14 +72,16 @@ def test_universe_import_creates_updates_and_skips_records(tmp_path, monkeypatch conn = sqlite3.connect(db_path) try: - rows = conn.execute("SELECT name, cik, jurisdiction FROM managers ORDER BY cik").fetchall() + rows = conn.execute( + "SELECT name, cik, jurisdiction, jurisdictions FROM managers ORDER BY cik" + ).fetchall() finally: conn.close() assert rows == [ - ("Berkshire Hathaway Inc.", "0001067983", "us"), - ("Bridgewater Associates", "0001350694", "us"), - ("Citadel Advisors", "0001423053", "us"), + ("Berkshire Hathaway Inc.", "0001067983", "us", '["us"]'), + ("Bridgewater Associates", "0001350694", "us", '["us"]'), + ("Citadel Advisors", "0001423053", "us", '["us"]'), ] @@ -104,11 +118,13 @@ def test_universe_import_skips_non_object_records(tmp_path, monkeypatch): conn = sqlite3.connect(db_path) try: - rows = conn.execute("SELECT name, cik, jurisdiction FROM managers").fetchall() + rows = conn.execute( + "SELECT name, cik, jurisdiction, jurisdictions FROM managers" + ).fetchall() finally: conn.close() - assert rows == [("Pershing Square Capital Management, L.P.", "0001336528", "us")] + assert rows == [("Pershing Square Capital Management, L.P.", "0001336528", "us", '["us"]')] def test_universe_import_empty_array_returns_zero_counts(tmp_path, monkeypatch): @@ -137,11 +153,13 @@ def test_universe_import_upserts_on_cik_conflict_within_single_request(tmp_path, conn = sqlite3.connect(db_path) try: - rows = conn.execute("SELECT name, cik, jurisdiction FROM managers ORDER BY cik").fetchall() + rows = conn.execute( + "SELECT name, cik, jurisdiction, jurisdictions FROM managers ORDER BY cik" + ).fetchall() finally: conn.close() - assert rows == [("Berkshire Hathaway Inc.", "0001067983", "us")] + assert rows == [("Berkshire Hathaway Inc.", "0001067983", "us", '["us"]')] def test_manager_stats_uses_universe_jurisdiction_column(tmp_path, monkeypatch): @@ -166,3 +184,24 @@ def test_manager_stats_uses_universe_jurisdiction_column(tmp_path, monkeypatch): "with_cik": 2, "with_lei": 0, } + + +def test_universe_import_populates_jurisdictions_for_list_filter(tmp_path, monkeypatch): + db_path = tmp_path / "dev.db" + monkeypatch.setenv("DB_PATH", str(db_path)) + response = asyncio.run( + _post_universe( + [ + {"name": "Berkshire Hathaway", "cik": "0001067983", "jurisdiction": "us"}, + {"name": "TCI Fund Management", "cik": "0001647251", "jurisdiction": "uk"}, + ] + ) + ) + assert response.status_code == 200 + + filtered = asyncio.run(_get_managers({"jurisdiction": "us"})) + assert filtered.status_code == 200 + payload = filtered.json() + assert payload["total"] == 1 + assert [item["name"] for item in payload["items"]] == ["Berkshire Hathaway"] + assert payload["items"][0]["jurisdictions"] == ["us"] diff --git a/tests/test_seed_universe.py b/tests/test_seed_universe.py index 041e429f..0a8dd9b1 100644 --- a/tests/test_seed_universe.py +++ b/tests/test_seed_universe.py @@ -27,7 +27,9 @@ def _run_seed( def _fetch_rows(db_path: Path): conn = sqlite3.connect(db_path) try: - return conn.execute("SELECT name, cik, jurisdiction FROM managers ORDER BY cik").fetchall() + return conn.execute( + "SELECT name, cik, jurisdiction, jurisdictions FROM managers ORDER BY cik" + ).fetchall() finally: conn.close() @@ -50,8 +52,8 @@ def test_seed_universe_json_upsert_is_idempotent(tmp_path): rows = _fetch_rows(tmp_path / "seed.db") assert rows == [ - ("Berkshire Hathaway", "0001067983", "us"), - ("Bridgewater Associates", "0001350694", "us"), + ("Berkshire Hathaway", "0001067983", "us", '["us"]'), + ("Bridgewater Associates", "0001350694", "us", '["us"]'), ]