Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 240 additions & 19 deletions api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import date
from threading import Lock
from typing import Any, cast

Expand Down Expand Up @@ -453,6 +454,106 @@ def _build_chat_client_info():
return build_chat_client()


def _is_sqlite_connection(conn: Any) -> bool:
return isinstance(conn, sqlite3.Connection)


def _manager_id_column(conn: Any) -> str:
    """Return the identifier column name of the ``managers`` table.

    Supports both schema variants by introspecting the live database:
    prefers ``manager_id`` when present, otherwise falls back to the
    legacy ``id`` column.
    """
    if _is_sqlite_connection(conn):
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
        table_rows = conn.execute("PRAGMA table_info(managers)").fetchall()
        column_names = {str(table_row[1]) for table_row in table_rows}
    else:
        table_rows = conn.execute(
            "SELECT column_name FROM information_schema.columns "
            "WHERE table_schema = current_schema() AND table_name = %s",
            ("managers",),
        ).fetchall()
        column_names = {str(table_row[0]) for table_row in table_rows}
    return "manager_id" if "manager_id" in column_names else "id"


def _normalize_manager_ids(value: Any) -> list[int]:
raw_values = value if isinstance(value, list) else [value]
manager_ids: list[int] = []
for raw_value in raw_values:
try:
manager_ids.append(int(raw_value))
except (TypeError, ValueError):
continue
return manager_ids


def _normalize_cusips(value: Any) -> list[str]:
raw_values = value if isinstance(value, list) else [value]
cusips: list[str] = []
for raw_value in raw_values:
text = str(raw_value or "").strip().upper()
if text:
cusips.append(text)
return cusips


def _parse_date_range(value: Any) -> tuple[date, date] | None:
if isinstance(value, dict):
start_raw = value.get("start")
end_raw = value.get("end")
elif isinstance(value, (list, tuple)) and len(value) == 2:
start_raw, end_raw = value
else:
return None
if not start_raw or not end_raw:
return None
try:
return (date.fromisoformat(str(start_raw)), date.fromisoformat(str(end_raw)))
except ValueError:
return None


def _manager_ids_for_name(conn: Any, manager_name: str) -> list[int]:
    """Look up manager ids whose name matches *manager_name*, case-insensitively.

    Returns an empty list for a blank name or when no row matches;
    non-integer id values in result rows are skipped defensively.
    """
    name = manager_name.strip()
    if not name:
        return []
    id_column = _manager_id_column(conn)
    # sqlite uses "?" placeholders; the Postgres driver uses "%s".
    param = "?" if _is_sqlite_connection(conn) else "%s"
    # id_column comes from a fixed two-value whitelist, so interpolating
    # it into the statement is injection-safe; the name stays parameterized.
    fetched = conn.execute(
        f"SELECT {id_column} FROM managers WHERE lower(name) = lower({param})",
        (name,),
    ).fetchall()
    ids: list[int] = []
    for record in fetched:
        try:
            ids.append(int(record[0]))
        except (TypeError, ValueError, IndexError):
            continue
    return ids


def _normalize_chain_context(context: dict[str, Any] | None, conn: Any) -> dict[str, Any]:
    """Return a sanitized copy of the chat context for chain execution.

    Resolves manager identifiers (from ``manager_ids``, a scalar
    ``manager_id``, or a ``manager_name`` database lookup, in that
    priority order), cleans CUSIPs, and canonicalizes the date range
    into ``{"start": iso, "end": iso}`` form. Unrecognized keys pass
    through unchanged; the input dict is never mutated.
    """
    if not context:
        return {}

    normalized = dict(context)

    resolved_ids = _normalize_manager_ids(normalized.get("manager_ids"))
    if not resolved_ids:
        resolved_ids = _normalize_manager_ids(normalized.get("manager_id"))
    manager_name = normalized.get("manager_name")
    if not resolved_ids and isinstance(manager_name, str):
        resolved_ids = _manager_ids_for_name(conn, manager_name)
    if resolved_ids:
        normalized["manager_ids"] = resolved_ids

    if cleaned_cusips := _normalize_cusips(normalized.get("cusips")):
        normalized["cusips"] = cleaned_cusips

    date_range = _parse_date_range(normalized.get("date_range"))
    if date_range is not None:
        start, end = date_range
        normalized["date_range"] = {"start": start.isoformat(), "end": end.isoformat()}
    return normalized


def _classify_intent(question: str) -> str:
"""Classify user intent or use a deterministic fallback classifier."""
try:
Expand All @@ -475,15 +576,28 @@ def _classify_intent(question: str) -> str:


class _FallbackFilingSummaryChain:
def run(self, *, question: str, context: dict[str, Any] | None = None) -> dict[str, Any]:
filing_id = (context or {}).get("filing_id")
def run(self, filing_id: int) -> dict[str, Any]:
detail = f" for filing {filing_id}" if filing_id else ""
return {"answer": f"Filing summary{detail}: {question}", "sources": []}
return {"answer": f"Filing summary{detail}", "sources": []}


class _FallbackHoldingsAnalysisChain:
def run(self, *, question: str, context: dict[str, Any] | None = None) -> dict[str, Any]:
return {"answer": f"Holdings analysis: {question}", "sources": context or {}}
def run(
self,
question: str,
*,
manager_ids: list[int] | None = None,
cusips: list[str] | None = None,
date_range: tuple[date, date] | None = None,
) -> dict[str, Any]:
context: dict[str, Any] = {}
if manager_ids:
context["manager_ids"] = manager_ids
if cusips:
context["cusips"] = cusips
if date_range is not None:
context["date_range"] = [item.isoformat() for item in date_range]
return {"answer": f"Holdings analysis: {question}", "sources": [context] if context else []}


class _FallbackNLQueryChain:
Expand Down Expand Up @@ -512,15 +626,18 @@ def _build_chain(chain_name: str, client_info: Any):
module = importlib.import_module(module_path)
chain_cls = getattr(module, class_name)
except Exception:
chain_cls = FALLBACK_CHAINS[chain_name]
return FALLBACK_CHAINS[chain_name]()

db_conn = connect_db()
try:
return chain_cls(client_info.client if client_info else None)
if chain_name == "filing_summary":
return chain_cls(client_info=client_info, db_conn=db_conn)
if chain_name == "holdings_analysis":
return chain_cls(client_info=client_info, db_conn=db_conn)
return chain_cls(llm=client_info.client if client_info else None, db_conn=db_conn)
except Exception:
try:
return chain_cls(client_info) if client_info else chain_cls()
except Exception:
return chain_cls()
db_conn.close()
raise


def _normalize_sources(raw_sources: Any) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -555,6 +672,72 @@ def _extract_chain_payload(result: Any) -> tuple[str, list[dict[str, Any]], str
return answer, sources, sql, trace_url


def _format_filing_summary_payload(result: Any) -> dict[str, Any]:
payload = result.model_dump() if hasattr(result, "model_dump") else dict(result)
manager_name = str(payload.get("manager_name") or "Unknown manager")
filing_date = str(payload.get("filing_date") or "n/a")
total_positions = payload.get("total_positions")
total_aum_estimate = payload.get("total_aum_estimate") or "n/a"
lines = [
f"{manager_name} filing summary ({filing_date})",
f"Total positions: {total_positions}",
f"Estimated AUM: {total_aum_estimate}",
]
key_positions = payload.get("key_positions") or []
if isinstance(key_positions, list) and key_positions:
top_names = [
str(
position.get("name_of_issuer")
or position.get("issuer")
or position.get("cusip")
or ""
)
for position in key_positions[:5]
if isinstance(position, dict)
]
top_names = [name for name in top_names if name]
if top_names:
lines.append("Top positions: " + ", ".join(top_names))
notable_changes = payload.get("notable_changes") or []
if isinstance(notable_changes, list) and notable_changes:
lines.append("Notable changes: " + "; ".join(str(change) for change in notable_changes[:3]))
risk_flags = payload.get("risk_flags") or []
if isinstance(risk_flags, list) and risk_flags:
lines.append("Risk flags: " + "; ".join(str(flag) for flag in risk_flags[:3]))
return {"answer": "\n".join(lines), "sources": [], "sql": None, "trace_url": None}


def _format_holdings_analysis_payload(result: Any) -> dict[str, Any]:
payload = result.model_dump() if hasattr(result, "model_dump") else dict(result)
thesis = str(payload.get("thesis") or "Holdings analysis complete.")
lines = [thesis]
top_positions = payload.get("top_positions") or []
if isinstance(top_positions, list) and top_positions:
top_names = [
str(
position.get("name_of_issuer")
or position.get("issuer")
or position.get("cusip")
or ""
)
for position in top_positions[:5]
if isinstance(position, dict)
]
top_names = [name for name in top_names if name]
if top_names:
lines.append("Top positions: " + ", ".join(top_names))
period_changes = payload.get("period_changes") or []
if isinstance(period_changes, list) and period_changes:
lines.append(
"Period changes: "
+ "; ".join(str(change.get("summary") or change) for change in period_changes[:3])
)
concentration_metrics = payload.get("concentration_metrics") or {}
if isinstance(concentration_metrics, dict) and concentration_metrics:
lines.append("Concentration metrics: " + json.dumps(concentration_metrics, default=str))
return {"answer": "\n".join(lines), "sources": [], "sql": None, "trace_url": None}


def _response_id_from_trace_url(trace_url: str | None) -> str:
if trace_url:
stripped = trace_url.rstrip("/")
Expand Down Expand Up @@ -637,25 +820,63 @@ async def _run_chain(
chain_name: str, question: str, context: dict[str, Any] | None, client_info: Any
):
"""Execute a chain and return normalized payload."""
conn = connect_db()
try:
normalized_context = _normalize_chain_context(context, conn)
finally:
conn.close()

chain = _build_chain(chain_name, client_info)
if hasattr(chain, "run"):
result = chain.run(question=question, context=context)
elif hasattr(chain, "invoke"):
result = chain.invoke({"question": question, "context": context or {}})
else:
raise RuntimeError(f"Chain '{chain_name}' does not expose run/invoke")
try:
if chain_name == "filing_summary":
filing_id = normalized_context.get("filing_id")
try:
if filing_id is None:
raise TypeError("missing filing_id")
result = chain.run(int(cast(Any, filing_id)))
except (TypeError, ValueError) as exc:
Comment on lines +831 to +837
raise HTTPException(
status_code=400,
detail="filing_summary requires context.filing_id",
) from exc
result = _format_filing_summary_payload(result)
elif chain_name == "holdings_analysis":
result = chain.run(
question,
manager_ids=_normalize_manager_ids(normalized_context.get("manager_ids")) or None,
cusips=_normalize_cusips(normalized_context.get("cusips")) or None,
date_range=_parse_date_range(normalized_context.get("date_range")),
)
Comment on lines +843 to +849
result = _format_holdings_analysis_payload(result)
elif hasattr(chain, "run"):
result = chain.run(question=question, context=normalized_context or None)
elif hasattr(chain, "invoke"):
result = chain.invoke({"question": question, "context": normalized_context or {}})
else:
raise RuntimeError(f"Chain '{chain_name}' does not expose run/invoke")
finally:
close = getattr(chain, "db", None)
try:
if close is not None and hasattr(close, "close"):
close.close()
except Exception:
pass

if asyncio.iscoroutine(result):
result = await result

return _extract_chain_payload(result)


def _resolve_chain_name(requested_chain: str, question: str) -> str:
def _resolve_chain_name(
    requested_chain: str, question: str, context: dict[str, Any] | None = None
) -> str:
    """Resolve the requested/auto chain to a concrete chain implementation name.

    An explicit chain request is honored verbatim; ``auto`` is routed by
    intent classification, falling back to ``rag_search`` when the
    classified chain is unsupported or lacks its required context.
    """
    if requested_chain != "auto":
        return requested_chain
    intent = _classify_intent(question)
    has_filing_id = bool((context or {}).get("filing_id"))
    # filing_summary cannot run without a filing_id in context.
    if intent == "filing_summary" and not has_filing_id:
        return "rag_search"
    return intent if intent in DIRECT_CHAIN_PATHS else "rag_search"
Expand All @@ -677,7 +898,7 @@ async def chat_api(request: ChatRequest, raw_request: Request) -> ChatResponse:
requested_chain = (request.chain or "auto").strip().lower()
if requested_chain not in VALID_CHAIN_NAMES:
raise HTTPException(status_code=400, detail="Invalid chain")
chain_used = _resolve_chain_name(requested_chain, request.question)
chain_used = _resolve_chain_name(requested_chain, request.question, request.context)

answer, sources, sql, trace_url = await _run_chain(
chain_used, request.question, request.context, client_info
Expand Down
47 changes: 43 additions & 4 deletions chains/nl_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,19 @@ def __init__(self, llm: Any | None = None, db_conn: Any | None = None) -> None:
self._schema_ddl = self._load_schema_ddl()
self._known_tables = self._extract_known_tables(self._schema_ddl)

def _guard_context(self, context: dict[str, Any] | None) -> None:
if not context:
return
for value in context.values():
if isinstance(value, str):
guard_input(value)
elif isinstance(value, dict):
self._guard_context(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, str):
guard_input(item)

def _load_schema_ddl(self) -> str:
"""Load public CREATE TABLE statements and omit internal bookkeeping tables."""
schema_path = Path(__file__).resolve().parents[1] / "schema.sql"
Expand Down Expand Up @@ -164,9 +177,35 @@ def _format_results(self, results: list[dict[str, Any]], question: str) -> str:
f"Columns: {column_summary}. Sample rows: {json.dumps(sample, default=str)}"
)

def _prompt_text(self, question: str) -> str:
def _context_prompt(self, context: dict[str, Any] | None) -> str:
if not context:
return ""
filters: list[str] = []
manager_ids = context.get("manager_ids")
if manager_ids:
filters.append(f"- Restrict results to manager_ids={manager_ids}")
manager_name = context.get("manager_name")
if manager_name:
filters.append(f"- Restrict results to manager_name={manager_name}")
filing_id = context.get("filing_id")
if filing_id:
filters.append(f"- Restrict results to filing_id={filing_id}")
date_range = context.get("date_range")
if date_range:
filters.append(f"- Restrict results to date_range={date_range}")
cusips = context.get("cusips")
if cusips:
filters.append(f"- Restrict results to cusips={cusips}")
if not filters:
return ""
return "Context filters:\n" + "\n".join(filters) + "\n"

def _prompt_text(self, question: str, context: dict[str, Any] | None = None) -> str:
return (
NL_QUERY_SYSTEM_PROMPT.format(schema_ddl=self._schema_ddl) + f"\nQuestion: {question}\n"
NL_QUERY_SYSTEM_PROMPT.format(schema_ddl=self._schema_ddl)
+ "\n"
+ self._context_prompt(context)
+ f"Question: {question}\n"
)

def _invoke_llm(self, prompt: str) -> tuple[str, str | None]:
Expand Down Expand Up @@ -204,9 +243,9 @@ def _parse_llm_result(self, raw_response: str) -> NLQueryResult:

def run(self, question: str, context: dict[str, Any] | None = None) -> dict[str, Any]:
"""Run the NL-to-SQL pipeline end-to-end."""
del context # Reserved for future filters.
guard_input(question)
prompt = self._prompt_text(question)
self._guard_context(context)
prompt = self._prompt_text(question, context)
raw_response, trace_url = self._invoke_llm(prompt)
parsed = self._parse_llm_result(raw_response)
normalized_sql = self._normalize_sql(parsed.sql)
Expand Down
Loading
Loading