diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0e6f787..ebf3415 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+### Fixed
+- `_LazyStore` thread safety — added double-checked locking to prevent duplicate `PostgresStore`/connection pool creation under concurrent access from embedding workers, cleanup thread, or parallel requests ([#164](https://github.com/cmeans/mcp-awareness/issues/164))
+
 ## [0.16.1] - 2026-04-09
 
 ### Fixed
diff --git a/README.md b/README.md
index d1b5a7c..3937fb5 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@
 > **Your AI's memory shouldn't be locked to one app. It should follow you everywhere.**
 
 > [!NOTE]
-> Early-stage but actively deployed — 725 tests, 16 releases, in daily use across Claude.ai, Claude Code, and Claude Desktop. See [Current status](#current-status) for what's working and what's planned.
+> Early-stage but actively deployed — 729 tests, 16 releases, in daily use across Claude.ai, Claude Code, and Claude Desktop. See [Current status](#current-status) for what's working and what's planned.
 
 ## What this is
 
@@ -396,7 +396,7 @@ For single-user deployments, secret path + WAF is sufficient. For multi-user, en
 - Secret path auth + Cloudflare WAF for edge-level access control
 - Docker Compose with Postgres, optional Ollama, named Cloudflare Tunnel, or ephemeral quick tunnel
 - Request timing instrumentation and `/health` endpoint
-- 725 tests (all against real Postgres + Ollama in CI), strict type checking, CI pipeline with coverage, QA gate
+- 729 tests (all against real Postgres + Ollama in CI), strict type checking, CI pipeline with coverage, QA gate
 
 ### Not yet implemented
 - Layer 2 (baseline) detection — rolling averages and deviation calculation
diff --git a/src/mcp_awareness/server.py b/src/mcp_awareness/server.py
index e05bace..d2ddcf9 100644
--- a/src/mcp_awareness/server.py
+++ b/src/mcp_awareness/server.py
@@ -32,6 +32,7 @@
 import os
 import pathlib
 import re
+import threading
 import time
 from datetime import datetime, timezone
 from typing import Any, Literal
@@ -174,10 +175,16 @@ class _LazyStore:
     """
 
     _instance: Store | None = None
+    _lock: threading.Lock = threading.Lock()
 
     def __getattr__(self, name: str) -> Any:
+        # Double-checked locking: safe in CPython — GIL ensures atomic
+        # reference assignment, so the outer check never sees a
+        # partially-constructed object.
         if _LazyStore._instance is None:
-            _LazyStore._instance = _create_store()
+            with _LazyStore._lock:
+                if _LazyStore._instance is None:
+                    _LazyStore._instance = _create_store()
         return getattr(_LazyStore._instance, name)
 
 
diff --git a/tests/test_server.py b/tests/test_server.py
index 8521909..f4fd2a3 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -20,6 +20,7 @@
 
 import json
 import os
+import threading
 
 import pytest
 from mcp.server.fastmcp.exceptions import ToolError
@@ -81,6 +82,196 @@ def test_raises_without_database_url(self, monkeypatch: pytest.MonkeyPatch) -> N
             server_mod._create_store()
 
 
+class TestLazyStoreThreadSafety:
+    """Verify _LazyStore only creates one store under concurrent access."""
+
+    def test_concurrent_access_creates_single_instance(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """Multiple threads racing through __getattr__ must produce exactly one store."""
+        call_count = 0
+        lock = threading.Lock()
+        # Bounded barrier: a straggler breaks the barrier instead of hanging the suite.
+        barrier = threading.Barrier(10, timeout=5)
+
+        class _FakeStore:
+            def ping(self) -> str:
+                return "ok"
+
+        def _counting_create_store() -> _FakeStore:
+            nonlocal call_count
+            with lock:
+                call_count += 1
+            return _FakeStore()
+
+        monkeypatch.setattr(server_mod, "_create_store", _counting_create_store)
+        # monkeypatch puts the previous singleton back at teardown, even on failure.
+        monkeypatch.setattr(server_mod._LazyStore, "_instance", None)
+
+        lazy = server_mod._LazyStore()
+        errors: list[Exception] = []
+
+        def _access() -> None:
+            try:
+                barrier.wait()  # all threads launch simultaneously
+                lazy.ping()
+            except Exception as exc:
+                errors.append(exc)
+
+        threads = [threading.Thread(target=_access) for _ in range(10)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join(timeout=5)
+
+        hung = [t.name for t in threads if t.is_alive()]
+        assert not hung, f"Threads did not finish in time: {hung}"
+        assert not errors, f"Threads raised: {errors}"
+        assert call_count == 1, f"_create_store called {call_count} times, expected 1"
+
+    def test_cleanup_thread_and_request_handler_race(
+        self, store: Store, pg_dsn: str, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """Cleanup daemon thread and request handler must share a single store."""
+        monkeypatch.setenv("AWARENESS_DATABASE_URL", pg_dsn)
+        call_count = 0
+        count_lock = threading.Lock()
+        go = threading.Event()
+
+        real_create = server_mod._create_store
+
+        def _tracking_create_store() -> Store:
+            nonlocal call_count
+            with count_lock:
+                call_count += 1
+            return real_create()
+
+        monkeypatch.setattr(server_mod, "_create_store", _tracking_create_store)
+        monkeypatch.setattr(server_mod._LazyStore, "_instance", None)
+        lazy = server_mod._LazyStore()
+        errors: list[Exception] = []
+
+        def _simulate_cleanup() -> None:
+            """Simulates a daemon thread (like cleanup) accessing the store."""
+            try:
+                go.wait()
+                _ = lazy.get_stats
+            except Exception as exc:
+                errors.append(exc)
+
+        def _simulate_request() -> None:
+            """Simulates a request thread accessing the store."""
+            try:
+                go.wait()
+                _ = lazy.add
+            except Exception as exc:
+                errors.append(exc)
+
+        t_cleanup = threading.Thread(target=_simulate_cleanup, daemon=True)
+        t_request = threading.Thread(target=_simulate_request)
+        t_cleanup.start()
+        t_request.start()
+        go.set()
+        t_cleanup.join(timeout=5)
+        t_request.join(timeout=5)
+
+        hung = [t.name for t in (t_cleanup, t_request) if t.is_alive()]
+        assert not hung, f"Threads did not finish in time: {hung}"
+        assert not errors, f"Threads raised: {errors}"
+        assert call_count == 1, f"_create_store called {call_count} times, expected 1"
+
+    def test_embedding_worker_and_request_handler_race(
+        self, store: Store, pg_dsn: str, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """Embedding thread pool worker and request handler must share a single store."""
+        monkeypatch.setenv("AWARENESS_DATABASE_URL", pg_dsn)
+        call_count = 0
+        count_lock = threading.Lock()
+        go = threading.Event()
+
+        real_create = server_mod._create_store
+
+        def _tracking_create_store() -> Store:
+            nonlocal call_count
+            with count_lock:
+                call_count += 1
+            return real_create()
+
+        monkeypatch.setattr(server_mod, "_create_store", _tracking_create_store)
+        monkeypatch.setattr(server_mod._LazyStore, "_instance", None)
+        lazy = server_mod._LazyStore()
+        errors: list[Exception] = []
+
+        def _simulate_embedding() -> None:
+            """Simulates an embedding worker accessing store.upsert_embedding."""
+            try:
+                go.wait()
+                _ = lazy.upsert_embedding
+            except Exception as exc:
+                errors.append(exc)
+
+        def _simulate_request() -> None:
+            """Simulates a request thread accessing store.get_knowledge."""
+            try:
+                go.wait()
+                _ = lazy.get_knowledge
+            except Exception as exc:
+                errors.append(exc)
+
+        t_embed = threading.Thread(target=_simulate_embedding, name="embed-0")
+        t_request = threading.Thread(target=_simulate_request)
+        t_embed.start()
+        t_request.start()
+        go.set()
+        t_embed.join(timeout=5)
+        t_request.join(timeout=5)
+
+        hung = [t.name for t in (t_embed, t_request) if t.is_alive()]
+        assert not hung, f"Threads did not finish in time: {hung}"
+        assert not errors, f"Threads raised: {errors}"
+        assert call_count == 1, f"_create_store called {call_count} times, expected 1"
+
+    def test_concurrent_real_postgres_store_creation(
+        self, pg_dsn: str, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """Integration: 10 threads racing to create a real PostgresStore + pool."""
+        monkeypatch.setenv("AWARENESS_DATABASE_URL", pg_dsn)
+        call_count = 0
+        count_lock = threading.Lock()
+        barrier = threading.Barrier(10, timeout=10)
+
+        real_create = server_mod._create_store
+
+        def _tracking_create_store() -> Store:
+            nonlocal call_count
+            with count_lock:
+                call_count += 1
+            return real_create()
+
+        monkeypatch.setattr(server_mod, "_create_store", _tracking_create_store)
+        monkeypatch.setattr(server_mod._LazyStore, "_instance", None)
+        lazy = server_mod._LazyStore()
+        errors: list[Exception] = []
+
+        def _access() -> None:
+            try:
+                barrier.wait()  # synchronize start, not factory entry
+                _ = lazy.add  # triggers __getattr__ → _create_store
+            except Exception as exc:
+                errors.append(exc)
+
+        threads = [threading.Thread(target=_access) for _ in range(10)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join(timeout=10)
+
+        hung = [t.name for t in threads if t.is_alive()]
+        assert not hung, f"Threads did not finish in time: {hung}"
+        assert not errors, f"Threads raised: {errors}"
+        assert call_count == 1, f"_create_store called {call_count} times, expected 1"
+
+
 # ---------------------------------------------------------------------------
 # Resource tests
 # ---------------------------------------------------------------------------