Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions api/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ def _ensure_universe_schema(conn: Any) -> None:
conn.execute("ALTER TABLE managers ADD COLUMN cik TEXT")
if "jurisdiction" not in columns:
conn.execute("ALTER TABLE managers ADD COLUMN jurisdiction TEXT")
if "jurisdictions" not in columns:
conn.execute("ALTER TABLE managers ADD COLUMN jurisdictions TEXT NOT NULL DEFAULT '[]'")
if "created_at" not in columns:
conn.execute("ALTER TABLE managers ADD COLUMN created_at TIMESTAMP")
conn.execute(
Expand All @@ -239,6 +241,7 @@ def _ensure_universe_schema(conn: Any) -> None:

conn.execute("ALTER TABLE managers ADD COLUMN IF NOT EXISTS cik text")
conn.execute("ALTER TABLE managers ADD COLUMN IF NOT EXISTS jurisdiction text")
conn.execute("ALTER TABLE managers ADD COLUMN IF NOT EXISTS jurisdictions text[] DEFAULT '{}'")
conn.execute(
"ALTER TABLE managers ADD COLUMN IF NOT EXISTS created_at timestamptz DEFAULT now()"
)
Expand All @@ -258,31 +261,34 @@ def _manager_exists_for_cik(conn: Any, cik: str) -> bool:


def _upsert_universe_record(conn: Any, name: str, cik: str, jurisdiction: str) -> None:
jurisdictions = [jurisdiction]
if isinstance(conn, sqlite3.Connection):
conn.execute(
"""
INSERT INTO managers(name, cik, jurisdiction, updated_at)
VALUES (?, ?, ?, CURRENT_TIMESTAMP)
INSERT INTO managers(name, cik, jurisdiction, jurisdictions, updated_at)
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
ON CONFLICT(cik)
DO UPDATE SET
name = excluded.name,
jurisdiction = excluded.jurisdiction,
jurisdictions = excluded.jurisdictions,
updated_at = CURRENT_TIMESTAMP
""",
(name, cik, jurisdiction),
(name, cik, jurisdiction, json.dumps(jurisdictions)),
)
return
conn.execute(
"""
INSERT INTO managers(name, cik, jurisdiction, updated_at)
VALUES (%s, %s, %s, now())
INSERT INTO managers(name, cik, jurisdiction, jurisdictions, updated_at)
VALUES (%s, %s, %s, %s, now())
ON CONFLICT(cik)
DO UPDATE SET
name = EXCLUDED.name,
jurisdiction = EXCLUDED.jurisdiction,
jurisdictions = EXCLUDED.jurisdictions,
updated_at = now()
""",
(name, cik, jurisdiction),
(name, cik, jurisdiction, jurisdictions),
)


Expand Down
125 changes: 99 additions & 26 deletions scripts/seed_universe.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def _ensure_universe_schema(conn: Any) -> None:
department TEXT,
cik TEXT,
jurisdiction TEXT,
jurisdictions TEXT NOT NULL DEFAULT '[]',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
Expand All @@ -60,6 +61,8 @@ def _ensure_universe_schema(conn: Any) -> None:
conn.execute("ALTER TABLE managers ADD COLUMN cik TEXT")
if "jurisdiction" not in columns:
conn.execute("ALTER TABLE managers ADD COLUMN jurisdiction TEXT")
if "jurisdictions" not in columns:
conn.execute("ALTER TABLE managers ADD COLUMN jurisdictions TEXT NOT NULL DEFAULT '[]'")
if "created_at" not in columns:
conn.execute("ALTER TABLE managers ADD COLUMN created_at TIMESTAMP")
conn.execute(
Expand All @@ -82,12 +85,14 @@ def _ensure_universe_schema(conn: Any) -> None:
department text,
cik text,
jurisdiction text,
jurisdictions text[] DEFAULT '{}',
created_at timestamptz DEFAULT now(),
updated_at timestamptz DEFAULT now()
)
""")
conn.execute("ALTER TABLE managers ADD COLUMN IF NOT EXISTS cik text")
conn.execute("ALTER TABLE managers ADD COLUMN IF NOT EXISTS jurisdiction text")
conn.execute("ALTER TABLE managers ADD COLUMN IF NOT EXISTS jurisdictions text[] DEFAULT '{}'")
conn.execute(
"ALTER TABLE managers ADD COLUMN IF NOT EXISTS created_at timestamptz DEFAULT now()"
)
Expand All @@ -97,13 +102,99 @@ def _ensure_universe_schema(conn: Any) -> None:
conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_managers_cik_unique ON managers(cik)")


def _manager_columns(conn: Any) -> set[str]:
if isinstance(conn, sqlite3.Connection):
return {
str(row[1]).lower() for row in conn.execute("PRAGMA table_info(managers)").fetchall()
}

rows = conn.execute("""
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'managers'
AND table_schema = 'public'
""").fetchall()
return {str(row[0]).lower() for row in rows if row and row[0] is not None}


def _existing_ciks(conn: Any) -> set[str]:
rows = conn.execute(
"SELECT cik FROM managers WHERE cik IS NOT NULL AND TRIM(cik) != ''"
).fetchall()
return {str(row[0]).strip() for row in rows if row and row[0] is not None}


def _upsert_universe_record(
conn: Any,
*,
name: str,
cik: str,
jurisdiction: str,
include_role: bool,
) -> None:
jurisdictions = [jurisdiction]
if isinstance(conn, sqlite3.Connection):
if include_role:
conn.execute(
"""
INSERT INTO managers(name, role, cik, jurisdiction, jurisdictions, updated_at)
VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
ON CONFLICT(cik)
DO UPDATE SET
name = excluded.name,
jurisdiction = excluded.jurisdiction,
jurisdictions = excluded.jurisdictions,
updated_at = CURRENT_TIMESTAMP
""",
(name, DEFAULT_ROLE, cik, jurisdiction, json.dumps(jurisdictions)),
)
return
conn.execute(
"""
INSERT INTO managers(name, cik, jurisdiction, jurisdictions, updated_at)
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
ON CONFLICT(cik)
DO UPDATE SET
name = excluded.name,
jurisdiction = excluded.jurisdiction,
jurisdictions = excluded.jurisdictions,
updated_at = CURRENT_TIMESTAMP
""",
(name, cik, jurisdiction, json.dumps(jurisdictions)),
)
return

if include_role:
conn.execute(
"""
INSERT INTO managers(name, role, cik, jurisdiction, jurisdictions, updated_at)
VALUES (%s, %s, %s, %s, %s, now())
ON CONFLICT(cik)
DO UPDATE SET
name = EXCLUDED.name,
jurisdiction = EXCLUDED.jurisdiction,
jurisdictions = EXCLUDED.jurisdictions,
updated_at = now()
""",
(name, DEFAULT_ROLE, cik, jurisdiction, jurisdictions),
)
return

conn.execute(
"""
INSERT INTO managers(name, cik, jurisdiction, jurisdictions, updated_at)
VALUES (%s, %s, %s, %s, now())
ON CONFLICT(cik)
DO UPDATE SET
name = EXCLUDED.name,
jurisdiction = EXCLUDED.jurisdiction,
jurisdictions = EXCLUDED.jurisdictions,
updated_at = now()
""",
(name, cik, jurisdiction, jurisdictions),
)


def seed_universe(file_path: Path, *, dry_run: bool = False) -> tuple[int, int, int]:
records = _load_records(file_path)
conn = connect_db()
Expand All @@ -113,6 +204,7 @@ def seed_universe(file_path: Path, *, dry_run: bool = False) -> tuple[int, int,
try:
_ensure_universe_schema(conn)
known_ciks = _existing_ciks(conn)
include_role = "role" in _manager_columns(conn)

for idx, record in enumerate(records):
name = str(record.get("name", "")).strip()
Expand All @@ -133,32 +225,13 @@ def seed_universe(file_path: Path, *, dry_run: bool = False) -> tuple[int, int,
if dry_run:
continue

if isinstance(conn, sqlite3.Connection):
conn.execute(
"""
INSERT INTO managers(name, role, cik, jurisdiction, updated_at)
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
ON CONFLICT(cik)
DO UPDATE SET
name = excluded.name,
jurisdiction = excluded.jurisdiction,
updated_at = CURRENT_TIMESTAMP
""",
(name, DEFAULT_ROLE, cik, jurisdiction),
)
else:
conn.execute(
"""
INSERT INTO managers(name, role, cik, jurisdiction, updated_at)
VALUES (%s, %s, %s, %s, now())
ON CONFLICT(cik)
DO UPDATE SET
name = EXCLUDED.name,
jurisdiction = EXCLUDED.jurisdiction,
updated_at = now()
""",
(name, DEFAULT_ROLE, cik, jurisdiction),
)
_upsert_universe_record(
conn,
name=name,
cik=cik,
jurisdiction=jurisdiction,
include_role=include_role,
)

if dry_run:
print("Dry run complete. No rows written.")
Expand Down
55 changes: 47 additions & 8 deletions tests/test_manager_universe_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ async def _get_manager_stats():
await app.router.shutdown()


async def _get_managers(params: dict[str, str] | None = None):
await app.router.startup()
try:
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(
transport=transport, base_url="http://test", timeout=5.0
) as client:
return await client.get("/managers", params=params)
finally:
await app.router.shutdown()


def test_universe_import_creates_updates_and_skips_records(tmp_path, monkeypatch):
db_path = tmp_path / "dev.db"
monkeypatch.setenv("DB_PATH", str(db_path))
Expand All @@ -60,14 +72,16 @@ def test_universe_import_creates_updates_and_skips_records(tmp_path, monkeypatch

conn = sqlite3.connect(db_path)
try:
rows = conn.execute("SELECT name, cik, jurisdiction FROM managers ORDER BY cik").fetchall()
rows = conn.execute(
"SELECT name, cik, jurisdiction, jurisdictions FROM managers ORDER BY cik"
).fetchall()
finally:
conn.close()

assert rows == [
("Berkshire Hathaway Inc.", "0001067983", "us"),
("Bridgewater Associates", "0001350694", "us"),
("Citadel Advisors", "0001423053", "us"),
("Berkshire Hathaway Inc.", "0001067983", "us", '["us"]'),
("Bridgewater Associates", "0001350694", "us", '["us"]'),
("Citadel Advisors", "0001423053", "us", '["us"]'),
]


Expand Down Expand Up @@ -104,11 +118,13 @@ def test_universe_import_skips_non_object_records(tmp_path, monkeypatch):

conn = sqlite3.connect(db_path)
try:
rows = conn.execute("SELECT name, cik, jurisdiction FROM managers").fetchall()
rows = conn.execute(
"SELECT name, cik, jurisdiction, jurisdictions FROM managers"
).fetchall()
finally:
conn.close()

assert rows == [("Pershing Square Capital Management, L.P.", "0001336528", "us")]
assert rows == [("Pershing Square Capital Management, L.P.", "0001336528", "us", '["us"]')]


def test_universe_import_empty_array_returns_zero_counts(tmp_path, monkeypatch):
Expand Down Expand Up @@ -137,11 +153,13 @@ def test_universe_import_upserts_on_cik_conflict_within_single_request(tmp_path,

conn = sqlite3.connect(db_path)
try:
rows = conn.execute("SELECT name, cik, jurisdiction FROM managers ORDER BY cik").fetchall()
rows = conn.execute(
"SELECT name, cik, jurisdiction, jurisdictions FROM managers ORDER BY cik"
).fetchall()
finally:
conn.close()

assert rows == [("Berkshire Hathaway Inc.", "0001067983", "us")]
assert rows == [("Berkshire Hathaway Inc.", "0001067983", "us", '["us"]')]


def test_manager_stats_uses_universe_jurisdiction_column(tmp_path, monkeypatch):
Expand All @@ -166,3 +184,24 @@ def test_manager_stats_uses_universe_jurisdiction_column(tmp_path, monkeypatch):
"with_cik": 2,
"with_lei": 0,
}


def test_universe_import_populates_jurisdictions_for_list_filter(tmp_path, monkeypatch):
db_path = tmp_path / "dev.db"
monkeypatch.setenv("DB_PATH", str(db_path))
response = asyncio.run(
_post_universe(
[
{"name": "Berkshire Hathaway", "cik": "0001067983", "jurisdiction": "us"},
{"name": "TCI Fund Management", "cik": "0001647251", "jurisdiction": "uk"},
]
)
)
assert response.status_code == 200

filtered = asyncio.run(_get_managers({"jurisdiction": "us"}))
assert filtered.status_code == 200
payload = filtered.json()
assert payload["total"] == 1
assert [item["name"] for item in payload["items"]] == ["Berkshire Hathaway"]
assert payload["items"][0]["jurisdictions"] == ["us"]
8 changes: 5 additions & 3 deletions tests/test_seed_universe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def _run_seed(
def _fetch_rows(db_path: Path):
conn = sqlite3.connect(db_path)
try:
return conn.execute("SELECT name, cik, jurisdiction FROM managers ORDER BY cik").fetchall()
return conn.execute(
"SELECT name, cik, jurisdiction, jurisdictions FROM managers ORDER BY cik"
).fetchall()
finally:
conn.close()

Expand All @@ -50,8 +52,8 @@ def test_seed_universe_json_upsert_is_idempotent(tmp_path):

rows = _fetch_rows(tmp_path / "seed.db")
assert rows == [
("Berkshire Hathaway", "0001067983", "us"),
("Bridgewater Associates", "0001350694", "us"),
("Berkshire Hathaway", "0001067983", "us", '["us"]'),
("Bridgewater Associates", "0001350694", "us", '["us"]'),
]


Expand Down
Loading