diff --git a/CHANGELOG.md b/CHANGELOG.md index d41b16c..ecf91f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Hash stability**: documented embedding hash behavior in `embeddings.py` module docstring and function docstrings — explains that changes to `compose_embedding_text()` invalidate all stored hashes and trigger mass re-embedding (MEDIUM #22) - **Data dictionary**: expanded `text_hash` column description in `docs/data-dictionary.md` to explain staleness detection and mass re-embedding on composition changes +### Changed +- **`upsert_by_logical_key` single-connection refactor**: the INSERT, existing-row fetch, and conditional UPDATE now share a single pooled connection and transaction instead of acquiring up to 3 separate connections, reducing pool contention under concurrency (MEDIUM #2) + ### Added - **JWKS auto-discovery**: when `AWARENESS_OAUTH_JWKS_URI` is not set, the server now fetches `/.well-known/openid-configuration` to discover the correct `jwks_uri` before falling back to `/.well-known/jwks.json` — fixes WorkOS compatibility (#126) - **OAuth user profile enrichment**: email and display_name populated from token claims on subsequent logins if missing diff --git a/src/mcp_awareness/postgres_store.py b/src/mcp_awareness/postgres_store.py index 667ffd2..fb1e778 100644 --- a/src/mcp_awareness/postgres_store.py +++ b/src/mcp_awareness/postgres_store.py @@ -621,12 +621,13 @@ def upsert_by_logical_key( ) -> tuple[Entry, bool]: """Upsert by source + logical_key. Returns (entry, created). - Uses a single connection with INSERT ... ON CONFLICT to avoid race - conditions when concurrent writers target the same logical_key. + Uses a single connection for the entire operation: INSERT attempt, + existing-row fetch, and conditional update all share one connection + and transaction to avoid pool contention under concurrency. """ with self._pool.connection() as conn, conn.transaction(), conn.cursor() as cur: self._set_rls_context(cur, owner_id) - # Attempt insert; on conflict, fetch the existing row's id + # Attempt insert; on conflict, return inserted=false cur.execute( _load_sql("upsert_by_logical_key"), ( @@ -646,27 +647,63 @@ def upsert_by_logical_key( assert row is not None inserted: bool = row["inserted"] - if inserted: - self._cleanup_expired() - return (entry, True) + if inserted: + self._cleanup_expired() + return (entry, True) - # Existing entry — compute diff and update if needed - existing = self._query_entries( - owner_id, "source = %s AND logical_key = %s", (source, logical_key) - ) - old = existing[0] - updates: dict[str, Any] = {} - if entry.tags != old.tags: - updates["tags"] = entry.tags - for field in ("description", "content", "content_type"): - new_val = entry.data.get(field) - old_val = old.data.get(field) - if new_val is not None and new_val != old_val: - updates[field] = new_val - if updates: - result = self.update_entry(owner_id, old.id, updates) - return (result or old, False) - return (old, False) + # Existing entry — fetch within the same connection + query_sql = _load_sql("query_entries").format( + where="source = %s AND logical_key = %s", + order_by="COALESCE(updated, created) DESC", + limit_clause="", + ) + cur.execute(query_sql, (owner_id, source, logical_key)) + rows = cur.fetchall() + old = self._row_to_entry(rows[0]) + + # Compute diff + updates: dict[str, Any] = {} + if entry.tags != old.tags: + updates["tags"] = entry.tags + for field in ("description", "content", "content_type"): + new_val = entry.data.get(field) + old_val = old.data.get(field) + if new_val is not None and new_val != old_val: + updates[field] = new_val + + if not updates: + return (old, False) + + # Apply updates inline (mirrors update_entry logic for knowledge types) + now = now_utc() + changed: dict[str, Any] = {} + if "tags" in updates and updates["tags"] != old.tags: + changed["tags"] = old.tags + old.tags = updates["tags"] + for field in ("description", "content", "content_type"): + if field in updates and updates[field] != old.data.get(field): + old_val = old.data.get(field) + if old_val is not None: + changed[field] = old_val + old.data[field] = updates[field] + + if changed: + changelog = old.data.setdefault("changelog", []) + changelog.append({"updated": to_iso(now), "changed": changed}) + old.updated = now + cur.execute( + _load_sql("update_entry"), + ( + now, + old.source, + json.dumps(old.tags), + json.dumps(old.data), + old.id, + owner_id, + ), + ) + + return (old, False) def get_stats(self, owner_id: str) -> dict[str, Any]: """Get entry counts by type, list of sources, and total count.""" diff --git a/tests/test_store.py b/tests/test_store.py index 5afdc0f..14e4746 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -2434,6 +2434,55 @@ def slow_cleanup() -> None: if store._cleanup_thread is not None: store._cleanup_thread.join(timeout=2) + def test_upsert_by_logical_key_creates_then_updates(self, store): + """First upsert creates the entry; second upsert updates it in place.""" + source = "upsert-single-conn" + logical_key = "lk-create-update" + now = now_utc() + + # First call — should create + entry1 = Entry( + id=make_id(), + type=EntryType.NOTE, + source=source, + tags=["v1"], + created=now, + expires=None, + data={"description": "original description"}, + logical_key=logical_key, + ) + result1, created1 = store.upsert_by_logical_key(TEST_OWNER, source, logical_key, entry1) + assert created1 is True + assert result1.id == entry1.id + assert result1.data["description"] == "original description" + + # Second call — same logical_key, different tags and description + entry2 = Entry( + id=make_id(), + type=EntryType.NOTE, + source=source, + tags=["v2"], + created=now, + expires=None, + data={"description": "updated description"}, + logical_key=logical_key, + ) + result2, created2 = store.upsert_by_logical_key(TEST_OWNER, source, logical_key, entry2) + assert created2 is False + # Should keep the original entry's id + assert result2.id == entry1.id + # Tags and description should reflect the update + assert result2.tags == ["v2"] + assert result2.data["description"] == "updated description" + # Changelog should record the change + assert "changelog" in result2.data + assert len(result2.data["changelog"]) >= 1 + + # Verify only one entry exists for this logical_key + results = store.get_knowledge(TEST_OWNER, source=source) + matching = [e for e in results if e.logical_key == logical_key] + assert len(matching) == 1 + def test_concurrent_upsert_by_logical_key(self, store): """Concurrent upserts with same source + logical_key must not create duplicates.""" source = "upsert-race"