diff --git a/CHANGELOG.md b/CHANGELOG.md
index 443eedd..779e817 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
+- **OAuth 2.1 resource server**: provider-agnostic JWKS-based token validation for external OAuth providers (WorkOS, Auth0, Cloudflare Access, Keycloak, etc.)
+- **Dual auth**: self-signed JWTs (via CLI) and OAuth provider tokens both accepted — OAuth for interactive clients, self-signed for edge providers/scripts
+- **User auto-provisioning**: auto-create user record on first valid OAuth login (`AWARENESS_OAUTH_AUTO_PROVISION`, default: false for tighter control during early access)
+- **Well-known metadata**: `/.well-known/oauth-protected-resource` (RFC 9728) for OAuth discovery by MCP clients
+- **OAuth env vars**: `AWARENESS_OAUTH_ISSUER`, `AWARENESS_OAUTH_AUDIENCE`, `AWARENESS_OAUTH_JWKS_URI`, `AWARENESS_OAUTH_USER_CLAIM`, `AWARENESS_OAUTH_AUTO_PROVISION`
- **JWT auth middleware**: opt-in via `AWARENESS_AUTH_REQUIRED=true`, validates Bearer tokens, extracts owner_id from `sub` claim
- **Row-level security**: Postgres RLS policies on all data tables as defense-in-depth alongside application-level owner_id filtering
- **CLI: `mcp-awareness-user`**: add/list/set-password/export/delete users with email normalization, E.164 phone validation, argon2id password hashing
diff --git a/alembic/versions/i4d5e6f7g8h9_add_oauth_columns.py b/alembic/versions/i4d5e6f7g8h9_add_oauth_columns.py
new file mode 100644
index 0000000..ad5a757
--- /dev/null
+++ b/alembic/versions/i4d5e6f7g8h9_add_oauth_columns.py
@@ -0,0 +1,58 @@
+# mcp-awareness — ambient system awareness for AI agents
+# Copyright (C) 2026 Chris Means
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""add OAuth identity columns to users table
+
+Revision ID: i4d5e6f7g8h9
+Revises: h3c4d5e6f7g8
+Create Date: 2026-03-29 20:00:00.000000
+
+"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "i4d5e6f7g8h9"
+down_revision: str | Sequence[str] | None = "h3c4d5e6f7g8"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+ # OAuth identity: sub claim + issuer for provider-agnostic lookup
+ op.execute("ALTER TABLE users ADD COLUMN IF NOT EXISTS oauth_subject TEXT")
+ op.execute("ALTER TABLE users ADD COLUMN IF NOT EXISTS oauth_issuer TEXT")
+ # Unique constraint: one OAuth identity per provider per user
+ op.execute(
+ "CREATE UNIQUE INDEX IF NOT EXISTS ix_users_oauth_identity "
+ "ON users (oauth_issuer, oauth_subject) WHERE oauth_issuer IS NOT NULL"
+ )
+ # Fast lookup index for every authenticated request
+ op.execute(
+ "CREATE INDEX IF NOT EXISTS ix_users_oauth_subject "
+ "ON users (oauth_subject) WHERE oauth_subject IS NOT NULL"
+ )
+
+
+def downgrade() -> None:
+ op.execute("DROP INDEX IF EXISTS ix_users_oauth_subject")
+ op.execute("DROP INDEX IF EXISTS ix_users_oauth_identity")
+ op.execute("ALTER TABLE users DROP COLUMN IF EXISTS oauth_issuer")
+ op.execute("ALTER TABLE users DROP COLUMN IF EXISTS oauth_subject")
diff --git a/pyproject.toml b/pyproject.toml
index c45025d..9bbfe73 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ dependencies = [
"psycopg_pool>=3.2",
"alembic>=1.13",
"sqlalchemy>=2.0",
- "PyJWT>=2.8",
+ "PyJWT[crypto]>=2.8",
"argon2-cffi>=23.1",
"phonenumbers>=8.13",
"zxcvbn>=4.4",
diff --git a/src/mcp_awareness/middleware.py b/src/mcp_awareness/middleware.py
index 7d32d09..40e6805 100644
--- a/src/mcp_awareness/middleware.py
+++ b/src/mcp_awareness/middleware.py
@@ -93,13 +93,61 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
+class WellKnownMiddleware:
+ """Serve /.well-known/oauth-protected-resource (RFC 9728)."""
+
+ def __init__(
+ self,
+ app: ASGIApp,
+ oauth_issuer: str,
+ host: str,
+ port: int,
+ mount_path: str = "",
+ ) -> None:
+ self.app = app
+ self.oauth_issuer = oauth_issuer.rstrip("/")
+ self.host = host
+ self.port = port
+ self.mount_path = mount_path
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] == "http":
+ path = scope.get("path", "")
+ if path == "/.well-known/oauth-protected-resource":
+ metadata = {
+ "resource": f"https://{self.host}:{self.port}{self.mount_path}/mcp",
+ "authorization_servers": [self.oauth_issuer],
+ "token_methods": ["Bearer"],
+ }
+ resp = JSONResponse(metadata)
+ await resp(scope, receive, send)
+ return
+ await self.app(scope, receive, send)
+
+
class AuthMiddleware:
- """Validate JWT Bearer token and set owner context."""
+ """Validate JWT Bearer token and set owner context.
- def __init__(self, app: ASGIApp, jwt_secret: str, algorithm: str = "HS256") -> None:
+ Supports dual auth: self-signed JWTs (via shared secret) and OAuth provider
+ tokens (via JWKS). Self-signed is tried first; if it fails and an OAuth
+ validator is configured, the token is validated against the provider's keys.
+ """
+
+ def __init__(
+ self,
+ app: ASGIApp,
+ jwt_secret: str,
+ algorithm: str = "HS256",
+ oauth_validator: object | None = None,
+ auto_provision: bool = True,
+ resource_metadata_url: str = "",
+ ) -> None:
self.app = app
self.jwt_secret = jwt_secret
self.algorithm = algorithm
+ self.oauth_validator = oauth_validator
+ self.auto_provision = auto_provision
+ self.resource_metadata_url = resource_metadata_url
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
@@ -107,8 +155,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return
path = scope.get("path", "")
- # Skip auth for health, favicon, and non-MCP paths
- if path in ("/health", "/favicon.ico"):
+ # Skip auth for health, favicon, well-known, and non-MCP paths
+ if path in ("/health", "/favicon.ico") or path.startswith("/.well-known/"):
await self.app(scope, receive, send)
return
@@ -116,32 +164,23 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
if not auth_header.startswith("Bearer "):
- resp = JSONResponse(
- {"error": "Missing or invalid Authorization header"}, status_code=401
- )
+ resp = self._unauthorized("Missing or invalid Authorization header")
await resp(scope, receive, send)
return
token = auth_header[7:] # Strip "Bearer "
- try:
- import jwt
- payload = jwt.decode(token, self.jwt_secret, algorithms=[self.algorithm])
- owner_id: str | None = payload.get("sub")
- if not owner_id:
- resp = JSONResponse({"error": "JWT missing 'sub' claim"}, status_code=401)
- await resp(scope, receive, send)
- return
- except Exception as exc:
- # Handle both ExpiredSignatureError and InvalidTokenError
- import jwt as jwt_mod
+ # Try self-signed JWT first
+ owner_id, error = self._try_self_signed(token)
- if isinstance(exc, jwt_mod.ExpiredSignatureError):
- resp = JSONResponse({"error": "Token expired"}, status_code=401)
- elif isinstance(exc, jwt_mod.InvalidTokenError):
- resp = JSONResponse({"error": "Invalid token"}, status_code=401)
- else:
- raise
+ # Fall back to OAuth provider validation
+ if owner_id is None and self.oauth_validator is not None:
+ owner_id = await self._try_oauth(token)
+ if owner_id is not None:
+ error = None
+
+ if owner_id is None:
+ resp = self._unauthorized(error or "Invalid token")
await resp(scope, receive, send)
return
@@ -153,3 +192,103 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
finally:
_owner_ctx.reset(token_reset)
+
+ def _try_self_signed(self, token: str) -> tuple[str | None, str | None]:
+ """Validate a self-signed JWT (from mcp-awareness-token CLI).
+
+ Returns (owner_id, error_message). On success error is None.
+ On failure owner_id is None and error carries the reason.
+ """
+ if not self.jwt_secret:
+ return None, None
+ try:
+ import jwt
+
+ payload = jwt.decode(token, self.jwt_secret, algorithms=[self.algorithm])
+ owner_id: str | None = payload.get("sub")
+ if owner_id:
+ return owner_id, None
+ return None, "JWT missing 'sub' claim"
+ except Exception as exc:
+ import jwt as jwt_mod
+
+ if isinstance(exc, jwt_mod.ExpiredSignatureError):
+ return None, "Token expired"
+ if isinstance(exc, jwt_mod.InvalidTokenError):
+ return None, "Invalid token"
+ return None, None
+
+ async def _try_oauth(self, token: str) -> str | None:
+ """Validate an OAuth token against the external provider's JWKS."""
+ from .oauth import OAuthTokenValidator
+
+ validator: OAuthTokenValidator = self.oauth_validator # type: ignore[assignment]
+ try:
+ claims = validator.validate(token)
+ except Exception:
+ return None
+
+ owner_id = claims["owner_id"]
+ oauth_subject = claims.get("oauth_subject")
+ oauth_issuer = claims.get("oauth_issuer")
+ email = claims.get("email")
+
+ # Resolve user identity: OAuth lookup → email link → auto-provision
+ resolved_id = self._resolve_user(
+ owner_id, email, claims.get("name"), oauth_subject, oauth_issuer
+ )
+
+ return resolved_id or owner_id
+
+ def _resolve_user(
+ self,
+ owner_id: str,
+ email: str | None,
+ display_name: str | None,
+ oauth_subject: str | None,
+ oauth_issuer: str | None,
+ ) -> str | None:
+ """Resolve OAuth token to a local user, linking or creating as needed.
+
+ Resolution order:
+ 1. Look up by OAuth identity (issuer + subject) — already linked user
+ 2. If email present, try to link to a pre-provisioned user by email
+ 3. If auto_provision enabled, create a new user
+ 4. Otherwise return None (use owner_id from token as-is)
+ """
+ try:
+ from .server import store
+
+ # 1. Already linked?
+ if oauth_issuer and oauth_subject:
+ existing = store.get_user_by_oauth(oauth_issuer, oauth_subject)
+ if existing:
+ return str(existing["id"])
+
+ # 2. Pre-provisioned user with matching email? Link on first login.
+ if email and oauth_subject and oauth_issuer:
+ linked_id = store.link_oauth_identity(oauth_subject, oauth_issuer, email)
+ if linked_id:
+ return str(linked_id)
+
+ # 3. Auto-provision new user
+ if self.auto_provision:
+ store.create_user_if_not_exists(
+ owner_id, email, display_name, oauth_subject, oauth_issuer
+ )
+ return owner_id
+
+ except Exception:
+ # Don't fail the request if user resolution fails
+ pass
+
+ return None
+
+ def _unauthorized(self, message: str) -> JSONResponse:
+ """Build a 401 response with proper WWW-Authenticate header."""
+ headers: dict[str, str] = {}
+ if self.resource_metadata_url:
+ headers["WWW-Authenticate"] = f'Bearer resource_metadata="{self.resource_metadata_url}"'
+ else:
+ headers["WWW-Authenticate"] = "Bearer"
+ return JSONResponse({"error": message}, status_code=401, headers=headers)
diff --git a/src/mcp_awareness/oauth.py b/src/mcp_awareness/oauth.py
new file mode 100644
index 0000000..7df23a8
--- /dev/null
+++ b/src/mcp_awareness/oauth.py
@@ -0,0 +1,103 @@
+# mcp-awareness — ambient system awareness for AI agents
+# Copyright (C) 2026 Chris Means
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""OAuth 2.1 resource server — JWKS-based token validation for external providers."""
+
+from __future__ import annotations
+
+import time
+
+import jwt
+from jwt import PyJWKClient
+
+
+class OAuthTokenValidator:
+ """Validates OAuth access tokens (JWTs) against an external provider's JWKS.
+
+ Provider-agnostic: works with any OIDC-compliant provider (WorkOS, Auth0,
+ Cloudflare Access, Keycloak, AWS Cognito, etc.).
+ """
+
+ def __init__(
+ self,
+ issuer: str,
+ audience: str = "",
+ jwks_uri: str = "",
+ user_claim: str = "sub",
+ jwks_cache_ttl: int = 3600,
+ ) -> None:
+ self.issuer = issuer.rstrip("/")
+ self.audience = audience
+ self.user_claim = user_claim
+
+ # JWKS URI: explicit override or derive from issuer
+ if jwks_uri:
+ self._jwks_uri = jwks_uri
+ else:
+ self._jwks_uri = f"{self.issuer}/.well-known/jwks.json"
+
+ self._jwk_client = PyJWKClient(self._jwks_uri, cache_jwk_set=True)
+ self._jwks_cache_ttl = jwks_cache_ttl
+ self._last_jwks_fetch: float = 0.0
+
+ def validate(self, token: str) -> dict[str, str]:
+ """Validate an OAuth JWT and return extracted identity claims.
+
+ Returns:
+ dict with 'owner_id' and optional 'email', 'name' keys.
+
+ Raises:
+ jwt.InvalidTokenError: if token is invalid, expired, or unverifiable.
+ """
+ # Refresh JWKS cache if stale
+ now = time.monotonic()
+ if now - self._last_jwks_fetch > self._jwks_cache_ttl:
+ self._jwk_client = PyJWKClient(self._jwks_uri, cache_jwk_set=True)
+ self._last_jwks_fetch = now
+
+ signing_key = self._jwk_client.get_signing_key_from_jwt(token)
+
+ # Build kwargs for jwt.decode
+ kwargs: dict[str, object] = {
+ "key": signing_key.key,
+ "algorithms": ["RS256", "ES256"],
+ "issuer": self.issuer,
+ }
+ if self.audience:
+ kwargs["audience"] = self.audience
+ else:
+ kwargs["options"] = {"verify_aud": False}
+
+ payload = jwt.decode(token, **kwargs) # type: ignore[arg-type]
+
+ # Extract owner_id from configured claim
+ owner_id = payload.get(self.user_claim)
+ if not owner_id:
+ raise jwt.InvalidTokenError(f"Token missing '{self.user_claim}' claim")
+
+ result: dict[str, str] = {"owner_id": str(owner_id)}
+
+ # Extract identity fields for auto-provisioning and lookup
+ if "sub" in payload:
+ result["oauth_subject"] = str(payload["sub"])
+ if "iss" in payload:
+ result["oauth_issuer"] = str(payload["iss"])
+ if "email" in payload:
+ result["email"] = str(payload["email"])
+ if "name" in payload:
+ result["name"] = str(payload["name"])
+
+ return result
diff --git a/src/mcp_awareness/postgres_store.py b/src/mcp_awareness/postgres_store.py
index 9afb0b2..d96cbf5 100644
--- a/src/mcp_awareness/postgres_store.py
+++ b/src/mcp_awareness/postgres_store.py
@@ -1171,6 +1171,52 @@ def get_referencing_entries(self, owner_id: str, entry_id: str) -> list[Entry]:
(json.dumps([entry_id]),),
)
+ # ------------------------------------------------------------------
+ # User operations (for OAuth auto-provisioning)
+ # ------------------------------------------------------------------
+
+ def get_user(self, user_id: str) -> dict[str, Any] | None:
+ """Look up a user by ID. Returns dict or None if not found."""
+ with self._pool.connection() as conn, conn.cursor(row_factory=dict_row) as cur:
+ cur.execute(_load_sql("get_user"), (user_id,))
+ return cur.fetchone()
+
+ def create_user_if_not_exists(
+ self,
+ user_id: str,
+ email: str | None = None,
+ display_name: str | None = None,
+ oauth_subject: str | None = None,
+ oauth_issuer: str | None = None,
+ ) -> None:
+ """Auto-provision a user on first OAuth login. No-op if user exists."""
+ with self._pool.connection() as conn, conn.transaction(), conn.cursor() as cur:
+ cur.execute(
+ _load_sql("create_user_auto"),
+ (user_id, email, display_name, oauth_subject, oauth_issuer),
+ )
+
+ def get_user_by_oauth(self, oauth_issuer: str, oauth_subject: str) -> dict[str, Any] | None:
+ """Look up a user by OAuth identity. Returns dict or None."""
+ with self._pool.connection() as conn, conn.cursor(row_factory=dict_row) as cur:
+ cur.execute(_load_sql("get_user_by_oauth"), (oauth_issuer, oauth_subject))
+ return cur.fetchone()
+
+ def link_oauth_identity(self, oauth_subject: str, oauth_issuer: str, email: str) -> str | None:
+ """Link an OAuth identity to a pre-provisioned user matched by email.
+
+ Returns the user ID if linked, None if no matching user found.
+ Only links if the user's oauth_subject is currently NULL (first-time link).
+ """
+ with (
+ self._pool.connection() as conn,
+ conn.transaction(),
+ conn.cursor(row_factory=dict_row) as cur,
+ ):
+ cur.execute(_load_sql("link_oauth_identity"), (oauth_subject, oauth_issuer, email))
+ row = cur.fetchone()
+ return str(row["id"]) if row else None
+
def clear(self, owner_id: str) -> None:
with self._pool.connection() as conn, conn.transaction(), conn.cursor() as cur:
self._set_rls_context(cur, owner_id)
diff --git a/src/mcp_awareness/server.py b/src/mcp_awareness/server.py
index 44a47b3..5724ec4 100644
--- a/src/mcp_awareness/server.py
+++ b/src/mcp_awareness/server.py
@@ -80,6 +80,13 @@
JWT_SECRET = os.environ.get("AWARENESS_JWT_SECRET", "")
JWT_ALGORITHM = os.environ.get("AWARENESS_JWT_ALGORITHM", "HS256")
+# OAuth — external provider (WorkOS, Auth0, Cloudflare Access, Keycloak, etc.)
+OAUTH_ISSUER = os.environ.get("AWARENESS_OAUTH_ISSUER", "")
+OAUTH_AUDIENCE = os.environ.get("AWARENESS_OAUTH_AUDIENCE", "")
+OAUTH_JWKS_URI = os.environ.get("AWARENESS_OAUTH_JWKS_URI", "")
+OAUTH_USER_CLAIM = os.environ.get("AWARENESS_OAUTH_USER_CLAIM", "sub")
+OAUTH_AUTO_PROVISION = os.environ.get("AWARENESS_OAUTH_AUTO_PROVISION", "false").lower() == "true"
+
# Embedding provider — optional, configured via env vars
EMBEDDING_PROVIDER = os.environ.get("AWARENESS_EMBEDDING_PROVIDER", "")
EMBEDDING_MODEL = os.environ.get("AWARENESS_EMBEDDING_MODEL", "nomic-embed-text")
@@ -332,24 +339,67 @@ def main() -> None:
print("Shutdown requested — exiting.", flush=True)
+def _build_oauth_validator() -> object | None:
+ """Create an OAuthTokenValidator if an external issuer is configured."""
+ if not OAUTH_ISSUER:
+ return None
+ from mcp_awareness.oauth import OAuthTokenValidator
+
+ return OAuthTokenValidator(
+ issuer=OAUTH_ISSUER,
+ audience=OAUTH_AUDIENCE,
+ jwks_uri=OAUTH_JWKS_URI,
+ user_claim=OAUTH_USER_CLAIM,
+ )
+
+
+def _build_resource_metadata_url() -> str:
+ """Build the well-known resource metadata URL for WWW-Authenticate headers."""
+ if MOUNT_PATH:
+ return f"{MOUNT_PATH}/.well-known/oauth-protected-resource"
+ return "/.well-known/oauth-protected-resource"
+
+
+def _wrap_with_auth(app: Any) -> Any:
+ """Wrap an ASGI app with AuthMiddleware if auth is required."""
+ if not AUTH_REQUIRED:
+ return app
+
+ oauth_validator = _build_oauth_validator()
+ if not JWT_SECRET and not oauth_validator:
+ raise ValueError(
+ "AWARENESS_AUTH_REQUIRED=true requires either "
+ "AWARENESS_JWT_SECRET or AWARENESS_OAUTH_ISSUER (or both)"
+ )
+ from mcp_awareness.middleware import AuthMiddleware
+
+ return AuthMiddleware(
+ app,
+ jwt_secret=JWT_SECRET,
+ algorithm=JWT_ALGORITHM,
+ oauth_validator=oauth_validator,
+ auto_provision=OAUTH_AUTO_PROVISION,
+ resource_metadata_url=_build_resource_metadata_url(),
+ )
+
+
def _run() -> None:
if TRANSPORT == "streamable-http" and MOUNT_PATH:
import uvicorn
from starlette.types import ASGIApp as _ASGIApp
- from mcp_awareness.middleware import SecretPathMiddleware
+ from mcp_awareness.middleware import (
+ SecretPathMiddleware,
+ WellKnownMiddleware,
+ )
inner_app = mcp.streamable_http_app()
app: _ASGIApp = SecretPathMiddleware(inner_app, MOUNT_PATH, _health_response)
- if AUTH_REQUIRED:
- if not JWT_SECRET:
- raise ValueError(
- "AWARENESS_JWT_SECRET is required when AWARENESS_AUTH_REQUIRED=true"
- )
- from mcp_awareness.middleware import AuthMiddleware
+ if OAUTH_ISSUER:
+ app = WellKnownMiddleware(app, OAUTH_ISSUER, HOST, PORT, MOUNT_PATH)
- app = AuthMiddleware(app, JWT_SECRET, JWT_ALGORITHM)
+ app = _wrap_with_auth(app)
config = uvicorn.Config(app, host=HOST, port=PORT)
server = uvicorn.Server(config)
@@ -363,16 +413,14 @@ def _run() -> None:
from mcp_awareness.middleware import HealthMiddleware
inner_app = mcp.streamable_http_app()
- health_app = HealthMiddleware(inner_app, _health_response)
+ health_app: Any = HealthMiddleware(inner_app, _health_response)
+
+ if OAUTH_ISSUER:
+ from mcp_awareness.middleware import WellKnownMiddleware
- if AUTH_REQUIRED:
- if not JWT_SECRET:
- raise ValueError(
- "AWARENESS_JWT_SECRET is required when AWARENESS_AUTH_REQUIRED=true"
- )
- from mcp_awareness.middleware import AuthMiddleware
+ health_app = WellKnownMiddleware(health_app, OAUTH_ISSUER, HOST, PORT, MOUNT_PATH)
- health_app = AuthMiddleware(health_app, JWT_SECRET, JWT_ALGORITHM) # type: ignore[assignment]
+ health_app = _wrap_with_auth(health_app)
config = uvicorn.Config(health_app, host=HOST, port=PORT)
server = uvicorn.Server(config)
diff --git a/src/mcp_awareness/sql/create_tables.sql b/src/mcp_awareness/sql/create_tables.sql
index 18c5267..39ea2d8 100644
--- a/src/mcp_awareness/sql/create_tables.sql
+++ b/src/mcp_awareness/sql/create_tables.sql
@@ -16,10 +16,19 @@ CREATE TABLE IF NOT EXISTS users (
display_name TEXT,
timezone TEXT DEFAULT 'UTC',
preferences JSONB NOT NULL DEFAULT '{{}}'::jsonb,
+ oauth_subject TEXT,
+ oauth_issuer TEXT,
created TIMESTAMPTZ NOT NULL DEFAULT now(),
updated TIMESTAMPTZ,
deleted TIMESTAMPTZ
);
+/* Ensure OAuth columns exist on tables created before migration i4d5e6f7g8h9 */
+ALTER TABLE users ADD COLUMN IF NOT EXISTS oauth_subject TEXT;
+ALTER TABLE users ADD COLUMN IF NOT EXISTS oauth_issuer TEXT;
+CREATE UNIQUE INDEX IF NOT EXISTS ix_users_oauth_identity
+ ON users (oauth_issuer, oauth_subject) WHERE oauth_issuer IS NOT NULL;
+CREATE INDEX IF NOT EXISTS ix_users_oauth_subject
+ ON users (oauth_subject) WHERE oauth_subject IS NOT NULL;
CREATE TABLE IF NOT EXISTS entries (
id TEXT PRIMARY KEY,
diff --git a/src/mcp_awareness/sql/create_user_auto.sql b/src/mcp_awareness/sql/create_user_auto.sql
new file mode 100644
index 0000000..03b1b13
--- /dev/null
+++ b/src/mcp_awareness/sql/create_user_auto.sql
@@ -0,0 +1,7 @@
+/* name: create_user_auto */
+/* mode: literal */
+/* Auto-provision a user on first OAuth login. No-op if user already exists.
+ Params: user_id, email, display_name, oauth_subject, oauth_issuer */
+INSERT INTO users (id, email, display_name, oauth_subject, oauth_issuer, created)
+VALUES (%s, %s, %s, %s, %s, now())
+ON CONFLICT (id) DO NOTHING
diff --git a/src/mcp_awareness/sql/get_user.sql b/src/mcp_awareness/sql/get_user.sql
new file mode 100644
index 0000000..c886110
--- /dev/null
+++ b/src/mcp_awareness/sql/get_user.sql
@@ -0,0 +1,4 @@
+/* name: get_user */
+/* mode: literal */
+/* Get a user by ID. Params: user_id */
+SELECT id, email, display_name, timezone, created FROM users WHERE id = %s AND deleted IS NULL
diff --git a/src/mcp_awareness/sql/get_user_by_oauth.sql b/src/mcp_awareness/sql/get_user_by_oauth.sql
new file mode 100644
index 0000000..5aeba20
--- /dev/null
+++ b/src/mcp_awareness/sql/get_user_by_oauth.sql
@@ -0,0 +1,6 @@
+/* name: get_user_by_oauth */
+/* mode: literal */
+/* Look up a user by OAuth identity (issuer + subject). Params: oauth_issuer, oauth_subject */
+SELECT id, email, display_name, timezone, oauth_subject, oauth_issuer, created
+FROM users
+WHERE oauth_issuer = %s AND oauth_subject = %s AND deleted IS NULL
diff --git a/src/mcp_awareness/sql/link_oauth_identity.sql b/src/mcp_awareness/sql/link_oauth_identity.sql
new file mode 100644
index 0000000..a6505d7
--- /dev/null
+++ b/src/mcp_awareness/sql/link_oauth_identity.sql
@@ -0,0 +1,10 @@
+/* name: link_oauth_identity */
+/* mode: literal */
+/* Link an OAuth identity to an existing user found by email.
+ Sets oauth_subject and oauth_issuer on first OAuth login.
+ Only updates if oauth_subject is currently NULL (first-time link).
+ Params: oauth_subject, oauth_issuer, email */
+UPDATE users
+SET oauth_subject = %s, oauth_issuer = %s, updated = now()
+WHERE email = %s AND oauth_subject IS NULL AND deleted IS NULL
+RETURNING id
diff --git a/tests/test_middleware.py b/tests/test_middleware.py
index 3717c2a..4c80323 100644
--- a/tests/test_middleware.py
+++ b/tests/test_middleware.py
@@ -334,7 +334,7 @@ def test_http_auth_required_no_secret_raises(self, monkeypatch: pytest.MonkeyPat
mock_app = MagicMock()
monkeypatch.setattr(server_mod.mcp, "streamable_http_app", lambda: mock_app)
- with pytest.raises(ValueError, match="AWARENESS_JWT_SECRET is required"):
+ with pytest.raises(ValueError, match="AWARENESS_AUTH_REQUIRED=true requires"):
server_mod._run()
@@ -528,11 +528,18 @@ async def noop_receive() -> dict[str, Any]:
async def noop_send(msg: dict[str, Any]) -> None:
pass
- with (
- patch("jwt.decode", side_effect=RuntimeError("boom")),
- pytest.raises(RuntimeError, match="boom"),
- ):
- await app(scope, noop_receive, noop_send)
+ sent: list[dict[str, Any]] = []
+
+ async def capture_send(msg: dict[str, Any]) -> None:
+ sent.append(msg)
+
+ with patch("jwt.decode", side_effect=RuntimeError("boom")):
+ await app(scope, noop_receive, capture_send)
+
+ # With dual auth, unexpected errors in self-signed validation
+ # fall through gracefully (returns 401, not a crash)
+ body = b"".join(m.get("body", b"") for m in sent if m["type"] == "http.response.body")
+ assert b"Invalid token" in body
class TestServerAuthWiring:
@@ -552,7 +559,7 @@ def test_http_with_mount_path_auth_no_secret_raises(
mock_app = MagicMock()
monkeypatch.setattr(server_mod.mcp, "streamable_http_app", lambda: mock_app)
- with pytest.raises(ValueError, match="AWARENESS_JWT_SECRET is required"):
+ with pytest.raises(ValueError, match="AWARENESS_AUTH_REQUIRED=true requires"):
server_mod._run()
def test_http_no_mount_path_auth_required_wraps_with_auth_middleware(
diff --git a/tests/test_oauth.py b/tests/test_oauth.py
new file mode 100644
index 0000000..4b9735f
--- /dev/null
+++ b/tests/test_oauth.py
@@ -0,0 +1,773 @@
+# mcp-awareness — ambient system awareness for AI agents
+# Copyright (C) 2026 Chris Means
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+"""Tests for OAuth 2.1 resource server (JWKS validation, auto-provisioning, metadata)."""
+
+from __future__ import annotations
+
+import json
+from datetime import datetime, timedelta, timezone
+from typing import Any
+from unittest.mock import MagicMock
+
+import jwt
+import pytest
+from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography.hazmat.primitives.serialization import (
+ Encoding,
+ NoEncryption,
+ PrivateFormat,
+)
+
+from mcp_awareness.middleware import AuthMiddleware, WellKnownMiddleware
+from mcp_awareness.oauth import OAuthTokenValidator
+
+# ---------------------------------------------------------------------------
+# RSA key pair for testing
+# ---------------------------------------------------------------------------
+
+_TEST_PRIVATE_KEY = rsa.generate_private_key(public_exponent=65537, key_size=2048)
+_TEST_PUBLIC_KEY = _TEST_PRIVATE_KEY.public_key()
+
+TEST_ISSUER = "https://auth.example.com"
+TEST_AUDIENCE = "awareness-test"
+TEST_OWNER = "test-owner"
+
+
+def _make_token(
+ sub: str = TEST_OWNER,
+ issuer: str = TEST_ISSUER,
+ audience: str = TEST_AUDIENCE,
+ email: str | None = None,
+ name: str | None = None,
+ expired: bool = False,
+) -> str:
+ """Create a signed RS256 JWT for testing."""
+ now = datetime.now(timezone.utc)
+ payload: dict[str, Any] = {
+ "sub": sub,
+ "iss": issuer,
+ "aud": audience,
+ "iat": now,
+ "exp": now + timedelta(hours=-1 if expired else 1),
+ }
+ if email:
+ payload["email"] = email
+ if name:
+ payload["name"] = name
+
+ private_pem = _TEST_PRIVATE_KEY.private_bytes(Encoding.PEM, PrivateFormat.PKCS8, NoEncryption())
+ return jwt.encode(payload, private_pem, algorithm="RS256")
+
+
+# ---------------------------------------------------------------------------
+# OAuthTokenValidator tests
+# ---------------------------------------------------------------------------
+
+
+class TestOAuthTokenValidator:
+ def _make_validator(self, audience: str = TEST_AUDIENCE) -> OAuthTokenValidator:
+ validator = OAuthTokenValidator(
+ issuer=TEST_ISSUER,
+ audience=audience,
+ user_claim="sub",
+ )
+ return validator
+
+ def _mock_jwk_client(self, validator: OAuthTokenValidator) -> None:
+ """Replace the PyJWKClient with a mock that returns our test key."""
+ import time as _time
+
+ mock_client = MagicMock()
+ mock_signing_key = MagicMock()
+ mock_signing_key.key = _TEST_PUBLIC_KEY
+ mock_client.get_signing_key_from_jwt.return_value = mock_signing_key
+ validator._jwk_client = mock_client
+ # Prevent validate() from replacing our mock with a real client
+ validator._last_jwks_fetch = _time.monotonic()
+
+ def test_valid_token(self) -> None:
+ validator = self._make_validator()
+ self._mock_jwk_client(validator)
+ token = _make_token()
+ result = validator.validate(token)
+ assert result["owner_id"] == TEST_OWNER
+
+ def test_token_with_email_and_name(self) -> None:
+ validator = self._make_validator()
+ self._mock_jwk_client(validator)
+ token = _make_token(email="alice@example.com", name="Alice")
+ result = validator.validate(token)
+ assert result["owner_id"] == TEST_OWNER
+ assert result["email"] == "alice@example.com"
+ assert result["name"] == "Alice"
+
+ def test_expired_token_raises(self) -> None:
+ validator = self._make_validator()
+ self._mock_jwk_client(validator)
+ token = _make_token(expired=True)
+ with pytest.raises(jwt.ExpiredSignatureError):
+ validator.validate(token)
+
+ def test_wrong_issuer_raises(self) -> None:
+ validator = self._make_validator()
+ self._mock_jwk_client(validator)
+ token = _make_token(issuer="https://wrong-issuer.com")
+ with pytest.raises(jwt.InvalidIssuerError):
+ validator.validate(token)
+
+ def test_wrong_audience_raises(self) -> None:
+ validator = self._make_validator(audience="wrong-audience")
+ self._mock_jwk_client(validator)
+ token = _make_token()
+ with pytest.raises(jwt.InvalidAudienceError):
+ validator.validate(token)
+
+ def test_missing_sub_claim_raises(self) -> None:
+ validator = self._make_validator()
+ self._mock_jwk_client(validator)
+ # Create a token without sub
+ now = datetime.now(timezone.utc)
+ payload = {
+ "iss": TEST_ISSUER,
+ "aud": TEST_AUDIENCE,
+ "iat": now,
+ "exp": now + timedelta(hours=1),
+ }
+ private_pem = _TEST_PRIVATE_KEY.private_bytes(
+ Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()
+ )
+ token = jwt.encode(payload, private_pem, algorithm="RS256")
+ with pytest.raises(jwt.InvalidTokenError, match="missing 'sub' claim"):
+ validator.validate(token)
+
+ def test_custom_user_claim(self) -> None:
+ validator = OAuthTokenValidator(
+ issuer=TEST_ISSUER,
+ audience=TEST_AUDIENCE,
+ user_claim="email",
+ )
+ self._mock_jwk_client(validator)
+ token = _make_token(email="alice@example.com")
+ result = validator.validate(token)
+ assert result["owner_id"] == "alice@example.com"
+
+ def test_explicit_jwks_uri(self) -> None:
+ """Explicit jwks_uri overrides the default derived from issuer."""
+ validator = OAuthTokenValidator(
+ issuer=TEST_ISSUER,
+ audience=TEST_AUDIENCE,
+ jwks_uri="https://custom.example.com/keys",
+ user_claim="sub",
+ )
+ assert validator._jwks_uri == "https://custom.example.com/keys"
+
+ def test_jwks_cache_refresh(self) -> None:
+ """JWKS client is refreshed when cache TTL expires."""
+ validator = self._make_validator()
+ self._mock_jwk_client(validator)
+ # Force cache to be expired
+ validator._last_jwks_fetch = 0.0
+ validator._jwks_cache_ttl = 0 # instant expiry
+ token = _make_token()
+ # This triggers cache refresh (creates new PyJWKClient) which fails
+ # because it tries to fetch from a fake URL. That's expected.
+ with pytest.raises(jwt.exceptions.PyJWKClientConnectionError):
+ validator.validate(token)
+
+ def test_no_audience_skips_validation(self) -> None:
+ validator = OAuthTokenValidator(
+ issuer=TEST_ISSUER,
+ audience="",
+ user_claim="sub",
+ )
+ self._mock_jwk_client(validator)
+ token = _make_token()
+ result = validator.validate(token)
+ assert result["owner_id"] == TEST_OWNER
+
+
+# ---------------------------------------------------------------------------
+# WellKnownMiddleware tests
+# ---------------------------------------------------------------------------
+
+
+class TestWellKnownMiddleware:
+ @pytest.mark.anyio
+ async def test_serves_protected_resource_metadata(self) -> None:
+ async def inner_app(scope: Any, receive: Any, send: Any) -> None:
+ pass
+
+ app = WellKnownMiddleware(inner_app, TEST_ISSUER, "localhost", 8420)
+ scope = {
+ "type": "http",
+ "path": "/.well-known/oauth-protected-resource",
+ "method": "GET",
+ "headers": [],
+ }
+
+ sent: list[dict[str, Any]] = []
+
+ async def noop_receive() -> dict[str, Any]:
+ return {"type": "http.request", "body": b""}
+
+ async def capture_send(msg: dict[str, Any]) -> None:
+ sent.append(msg)
+
+ await app(scope, noop_receive, capture_send)
+
+ body = b"".join(m.get("body", b"") for m in sent if m["type"] == "http.response.body")
+ data = json.loads(body)
+ assert data["authorization_servers"] == [TEST_ISSUER]
+ assert data["token_methods"] == ["Bearer"]
+ assert "/mcp" in data["resource"]
+
+ @pytest.mark.anyio
+ async def test_passes_through_other_paths(self) -> None:
+ called = False
+
+ async def inner_app(scope: Any, receive: Any, send: Any) -> None:
+ nonlocal called
+ called = True
+
+ app = WellKnownMiddleware(inner_app, TEST_ISSUER, "localhost", 8420)
+ scope = {"type": "http", "path": "/mcp", "method": "POST", "headers": []}
+
+ async def noop_receive() -> dict[str, Any]:
+ return {"type": "http.request", "body": b""}
+
+ async def noop_send(msg: dict[str, Any]) -> None:
+ pass
+
+ await app(scope, noop_receive, noop_send)
+ assert called
+
+
+# ---------------------------------------------------------------------------
+# AuthMiddleware dual auth tests
+# ---------------------------------------------------------------------------
+
+
+class TestDualAuth:
+ @pytest.mark.anyio
+ async def test_self_signed_jwt_still_works(self) -> None:
+ """Existing self-signed JWT auth continues to work with OAuth configured."""
+ secret = "test-secret-at-least-32-chars-long!"
+ token = jwt.encode({"sub": "alice"}, secret, algorithm="HS256")
+
+ called_with_owner: list[str] = []
+
+ async def inner_app(scope: Any, receive: Any, send: Any) -> None:
+ from mcp_awareness.server import _owner_ctx
+
+ called_with_owner.append(_owner_ctx.get())
+
+ mock_oauth = MagicMock()
+ app = AuthMiddleware(inner_app, jwt_secret=secret, oauth_validator=mock_oauth)
+
+ scope = {
+ "type": "http",
+ "path": "/mcp",
+ "method": "POST",
+ "headers": [(b"authorization", f"Bearer {token}".encode())],
+ }
+
+ async def noop_receive() -> dict[str, Any]:
+ return {"type": "http.request", "body": b""}
+
+ async def noop_send(msg: dict[str, Any]) -> None:
+ pass
+
+ await app(scope, noop_receive, noop_send)
+ assert called_with_owner == ["alice"]
+ # OAuth validator should NOT be called when self-signed JWT succeeds
+ mock_oauth.validate.assert_not_called()
+
+ @pytest.mark.anyio
+ async def test_oauth_fallback_when_self_signed_fails(self) -> None:
+ """When self-signed JWT fails, falls back to OAuth validation."""
+ oauth_token = "oauth-token-not-valid-as-self-signed"
+
+ mock_oauth = MagicMock()
+ mock_oauth.validate.return_value = {
+ "owner_id": "oauth-user",
+ "email": "oauth@example.com",
+ }
+
+ called_with_owner: list[str] = []
+
+ async def inner_app(scope: Any, receive: Any, send: Any) -> None:
+ from mcp_awareness.server import _owner_ctx
+
+ called_with_owner.append(_owner_ctx.get())
+
+ app = AuthMiddleware(
+ inner_app,
+ jwt_secret="some-secret-at-least-32-chars!!!",
+ oauth_validator=mock_oauth,
+ auto_provision=False,
+ )
+
+ scope = {
+ "type": "http",
+ "path": "/mcp",
+ "method": "POST",
+ "headers": [(b"authorization", f"Bearer {oauth_token}".encode())],
+ }
+
+ async def noop_receive() -> dict[str, Any]:
+ return {"type": "http.request", "body": b""}
+
+ async def noop_send(msg: dict[str, Any]) -> None:
+ pass
+
+ await app(scope, noop_receive, noop_send)
+ assert called_with_owner == ["oauth-user"]
+ mock_oauth.validate.assert_called_once_with(oauth_token)
+
+ @pytest.mark.anyio
+ async def test_well_known_bypasses_auth(self) -> None:
+ """/.well-known/ paths should not require authentication."""
+ sent: list[dict[str, Any]] = []
+
+ async def inner_app(scope: Any, receive: Any, send: Any) -> None:
+ pass
+
+ app = AuthMiddleware(inner_app, jwt_secret="secret-at-least-32-chars!!!!!!")
+
+ scope = {
+ "type": "http",
+ "path": "/.well-known/oauth-protected-resource",
+ "method": "GET",
+ "headers": [], # No auth header
+ }
+
+ async def noop_receive() -> dict[str, Any]:
+ return {"type": "http.request", "body": b""}
+
+ async def capture_send(msg: dict[str, Any]) -> None:
+ sent.append(msg)
+
+ await app(scope, noop_receive, capture_send)
+ # Should NOT return 401 — well-known paths are public
+ status_codes = [m.get("status") for m in sent if m["type"] == "http.response.start"]
+ assert 401 not in status_codes
+
+ @pytest.mark.anyio
+ async def test_resolve_user_called_on_oauth(self) -> None:
+ """When OAuth succeeds, _resolve_user is called with claims."""
+ mock_oauth = MagicMock()
+ mock_oauth.validate.return_value = {
+ "owner_id": "new-user",
+ "email": "new@example.com",
+ "name": "New User",
+ "oauth_subject": "sub-123",
+ "oauth_issuer": TEST_ISSUER,
+ }
+
+ called_with_owner: list[str] = []
+
+ async def inner_app(scope: Any, receive: Any, send: Any) -> None:
+ from mcp_awareness.server import _owner_ctx
+
+ called_with_owner.append(_owner_ctx.get())
+
+ app = AuthMiddleware(
+ inner_app,
+ jwt_secret="",
+ oauth_validator=mock_oauth,
+ auto_provision=False,
+ )
+
+ scope = {
+ "type": "http",
+ "path": "/mcp",
+ "method": "POST",
+ "headers": [(b"authorization", b"Bearer oauth-token")],
+ }
+
+ async def noop_receive() -> dict[str, Any]:
+ return {"type": "http.request", "body": b""}
+
+ async def noop_send(msg: dict[str, Any]) -> None:
+ pass
+
+ await app(scope, noop_receive, noop_send)
+ assert called_with_owner == ["new-user"]
+
+ @pytest.mark.anyio
+ async def test_401_includes_www_authenticate(self) -> None:
+ """401 responses include WWW-Authenticate with resource_metadata URL."""
+ sent: list[dict[str, Any]] = []
+
+ async def inner_app(scope: Any, receive: Any, send: Any) -> None:
+ pass
+
+ mock_oauth = MagicMock()
+ mock_oauth.validate.side_effect = Exception("fail")
+ app = AuthMiddleware(
+ inner_app,
+ jwt_secret="",
+ oauth_validator=mock_oauth,
+ resource_metadata_url="/.well-known/oauth-protected-resource",
+ )
+
+ scope = {
+ "type": "http",
+ "path": "/mcp",
+ "method": "POST",
+ "headers": [(b"authorization", b"Bearer bad-token")],
+ }
+
+ async def noop_receive() -> dict[str, Any]:
+ return {"type": "http.request", "body": b""}
+
+ async def capture_send(msg: dict[str, Any]) -> None:
+ sent.append(msg)
+
+ await app(scope, noop_receive, capture_send)
+
+ response_start = next(m for m in sent if m["type"] == "http.response.start")
+ assert response_start["status"] == 401
+ headers = dict(response_start.get("headers", []))
+ www_auth = headers.get(b"www-authenticate", b"").decode()
+ assert "resource_metadata" in www_auth
+
+
+# ---------------------------------------------------------------------------
+# Auto-provisioning tests
+# ---------------------------------------------------------------------------
+
+
+class TestAutoProvisionIntegration:
+ """Integration test: middleware auto-provision with a real store."""
+
+ @pytest.mark.anyio
+ async def test_ensure_user_creates_record(self, store: Any, monkeypatch: Any) -> None:
+ """_ensure_user calls store.create_user_if_not_exists through the server module."""
+ import mcp_awareness.server as server_mod
+
+ # Point the middleware at our test store
+ monkeypatch.setattr(server_mod, "store", store)
+
+ async def inner_app(scope: Any, receive: Any, send: Any) -> None:
+ pass
+
+ mock_oauth = MagicMock()
+ mock_oauth.validate.return_value = {
+ "owner_id": "integration-user",
+ "email": "int@example.com",
+ "name": "Integration",
+ "oauth_subject": "int-sub",
+ "oauth_issuer": TEST_ISSUER,
+ }
+
+ app = AuthMiddleware(
+ inner_app,
+ jwt_secret="",
+ oauth_validator=mock_oauth,
+ auto_provision=True,
+ )
+
+ scope = {
+ "type": "http",
+ "path": "/mcp",
+ "method": "POST",
+ "headers": [(b"authorization", b"Bearer oauth-token")],
+ }
+
+ async def noop_receive() -> dict[str, Any]:
+ return {"type": "http.request", "body": b""}
+
+ async def noop_send(msg: dict[str, Any]) -> None:
+ pass
+
+ await app(scope, noop_receive, noop_send)
+
+ # Verify user was created in the real store
+ user = store.get_user("integration-user")
+ assert user is not None
+ assert user["email"] == "int@example.com"
+
+ @pytest.mark.anyio
+ async def test_resolve_finds_already_linked_user(self, store: Any, monkeypatch: Any) -> None:
+ """OAuth login resolves to existing user via oauth_subject lookup."""
+ import mcp_awareness.server as server_mod
+
+ monkeypatch.setattr(server_mod, "store", store)
+
+ # Pre-create a linked user
+ store.create_user_if_not_exists(
+ "linked-alice", "alice@example.com", "Alice", "alice-sub", TEST_ISSUER
+ )
+
+ called_with_owner: list[str] = []
+
+ async def inner_app(scope: Any, receive: Any, send: Any) -> None:
+ from mcp_awareness.server import _owner_ctx
+
+ called_with_owner.append(_owner_ctx.get())
+
+ mock_oauth = MagicMock()
+ mock_oauth.validate.return_value = {
+ "owner_id": "alice-sub",
+ "email": "alice@example.com",
+ "oauth_subject": "alice-sub",
+ "oauth_issuer": TEST_ISSUER,
+ }
+
+ app = AuthMiddleware(
+ inner_app, jwt_secret="", oauth_validator=mock_oauth, auto_provision=False
+ )
+ scope = {
+ "type": "http",
+ "path": "/mcp",
+ "method": "POST",
+ "headers": [(b"authorization", b"Bearer oauth-token")],
+ }
+
+ async def noop_receive() -> dict[str, Any]:
+ return {"type": "http.request", "body": b""}
+
+ async def noop_send(msg: dict[str, Any]) -> None:
+ pass
+
+ await app(scope, noop_receive, noop_send)
+ # Should resolve to the existing user's ID, not the raw sub claim
+ assert called_with_owner == ["linked-alice"]
+
+ @pytest.mark.anyio
+ async def test_resolve_links_pre_provisioned_user_by_email(
+ self, store: Any, monkeypatch: Any
+ ) -> None:
+ """First OAuth login links to a pre-provisioned user matched by email."""
+ import mcp_awareness.server as server_mod
+
+ monkeypatch.setattr(server_mod, "store", store)
+
+ # Pre-provision via CLI (no OAuth identity)
+ store.create_user_if_not_exists("cli-bob", "bob@example.com", "Bob")
+
+ called_with_owner: list[str] = []
+
+ async def inner_app(scope: Any, receive: Any, send: Any) -> None:
+ from mcp_awareness.server import _owner_ctx
+
+ called_with_owner.append(_owner_ctx.get())
+
+ mock_oauth = MagicMock()
+ mock_oauth.validate.return_value = {
+ "owner_id": "bob-sub-xyz",
+ "email": "bob@example.com",
+ "oauth_subject": "bob-sub-xyz",
+ "oauth_issuer": TEST_ISSUER,
+ }
+
+ app = AuthMiddleware(
+ inner_app, jwt_secret="", oauth_validator=mock_oauth, auto_provision=False
+ )
+ scope = {
+ "type": "http",
+ "path": "/mcp",
+ "method": "POST",
+ "headers": [(b"authorization", b"Bearer oauth-token")],
+ }
+
+ async def noop_receive() -> dict[str, Any]:
+ return {"type": "http.request", "body": b""}
+
+ async def noop_send(msg: dict[str, Any]) -> None:
+ pass
+
+ await app(scope, noop_receive, noop_send)
+ # Should resolve to the pre-provisioned user's ID via email linking
+ assert called_with_owner == ["cli-bob"]
+ # Verify OAuth identity was linked
+ user = store.get_user_by_oauth(TEST_ISSUER, "bob-sub-xyz")
+ assert user is not None
+ assert user["id"] == "cli-bob"
+
+
+class TestAutoProvisionFailure:
+ """Verify _ensure_user swallows exceptions gracefully."""
+
+ @pytest.mark.anyio
+ async def test_ensure_user_exception_swallowed(self, monkeypatch: Any) -> None:
+ """Auto-provisioning failure must not block the request."""
+ import mcp_awareness.server as server_mod
+
+ broken_store = MagicMock()
+ broken_store.get_user_by_oauth.return_value = None
+ broken_store.link_oauth_identity.return_value = None
+ broken_store.create_user_if_not_exists.side_effect = RuntimeError("db down")
+ monkeypatch.setattr(server_mod, "store", broken_store)
+
+ mock_oauth = MagicMock()
+ mock_oauth.validate.return_value = {
+ "owner_id": "failing-user",
+ "oauth_subject": "sub",
+ "oauth_issuer": TEST_ISSUER,
+ }
+
+ called_with_owner: list[str] = []
+
+ async def inner_app(scope: Any, receive: Any, send: Any) -> None:
+ from mcp_awareness.server import _owner_ctx
+
+ called_with_owner.append(_owner_ctx.get())
+
+ app = AuthMiddleware(
+ inner_app, jwt_secret="", oauth_validator=mock_oauth, auto_provision=True
+ )
+ scope = {
+ "type": "http",
+ "path": "/mcp",
+ "method": "POST",
+ "headers": [(b"authorization", b"Bearer token")],
+ }
+
+ async def noop_receive() -> dict[str, Any]:
+ return {"type": "http.request", "body": b""}
+
+ async def noop_send(msg: dict[str, Any]) -> None:
+ pass
+
+ await app(scope, noop_receive, noop_send)
+ # Request should succeed despite provisioning failure
+ assert called_with_owner == ["failing-user"]
+
+
+class TestServerWiring:
+ def test_build_oauth_validator_returns_none_without_issuer(self) -> None:
+ """No OAuth validator when OAUTH_ISSUER is empty."""
+ from mcp_awareness import server as server_mod
+
+ original = server_mod.OAUTH_ISSUER
+ try:
+ server_mod.OAUTH_ISSUER = ""
+ assert server_mod._build_oauth_validator() is None
+ finally:
+ server_mod.OAUTH_ISSUER = original
+
+ def test_wrap_with_auth_uses_oauth_when_issuer_set(self) -> None:
+ """_wrap_with_auth creates AuthMiddleware with OAuth validator when issuer is set."""
+ from mcp_awareness import server as server_mod
+
+ orig_issuer = server_mod.OAUTH_ISSUER
+ orig_required = server_mod.AUTH_REQUIRED
+ orig_secret = server_mod.JWT_SECRET
+ try:
+ server_mod.OAUTH_ISSUER = TEST_ISSUER
+ server_mod.AUTH_REQUIRED = True
+ server_mod.JWT_SECRET = "" # No self-signed — OAuth only
+
+ async def dummy(scope: Any, receive: Any, send: Any) -> None:
+ pass
+
+ wrapped = server_mod._wrap_with_auth(dummy)
+ assert isinstance(wrapped, AuthMiddleware)
+ assert wrapped.oauth_validator is not None
+ finally:
+ server_mod.OAUTH_ISSUER = orig_issuer
+ server_mod.AUTH_REQUIRED = orig_required
+ server_mod.JWT_SECRET = orig_secret
+
+ def test_build_oauth_validator_returns_validator_with_issuer(self) -> None:
+ """OAuth validator created when OAUTH_ISSUER is set."""
+ from mcp_awareness import server as server_mod
+
+ original = server_mod.OAUTH_ISSUER
+ try:
+ server_mod.OAUTH_ISSUER = TEST_ISSUER
+ validator = server_mod._build_oauth_validator()
+ assert validator is not None
+ from mcp_awareness.oauth import OAuthTokenValidator
+
+ assert isinstance(validator, OAuthTokenValidator)
+ finally:
+ server_mod.OAUTH_ISSUER = original
+
+
+class TestAutoProvisioning:
+ def test_create_user_with_oauth_identity(self, store: Any) -> None:
+ """Auto-provisioning stores OAuth identity fields."""
+ store.create_user_if_not_exists(
+ "oauth-carol", "carol@example.com", "Carol", "carol-sub-123", TEST_ISSUER
+ )
+ user = store.get_user("oauth-carol")
+ assert user is not None
+ assert user["id"] == "oauth-carol"
+
+ def test_get_user_by_oauth(self, store: Any) -> None:
+ """Look up user by OAuth issuer + subject pair."""
+ store.create_user_if_not_exists(
+ "oauth-dan", "dan@example.com", "Dan", "dan-sub-456", TEST_ISSUER
+ )
+ user = store.get_user_by_oauth(TEST_ISSUER, "dan-sub-456")
+ assert user is not None
+ assert user["id"] == "oauth-dan"
+
+ def test_get_user_by_oauth_not_found(self, store: Any) -> None:
+ user = store.get_user_by_oauth(TEST_ISSUER, "nonexistent-sub")
+ assert user is None
+
+ def test_create_user_if_not_exists(self, store: Any) -> None:
+ """Auto-provisioning creates a user that didn't exist."""
+ store.create_user_if_not_exists("oauth-alice", "alice@example.com", "Alice")
+ user = store.get_user("oauth-alice")
+ assert user is not None
+ assert user["id"] == "oauth-alice"
+ assert user["email"] == "alice@example.com"
+ assert user["display_name"] == "Alice"
+
+ def test_create_user_if_not_exists_no_op_when_exists(self, store: Any) -> None:
+ """Auto-provisioning is a no-op for existing users."""
+ store.create_user_if_not_exists("oauth-bob", "bob@example.com", "Bob")
+ # Create again with different email — should NOT update
+ store.create_user_if_not_exists("oauth-bob", "new@example.com", "New Bob")
+ user = store.get_user("oauth-bob")
+ assert user is not None
+ assert user["email"] == "bob@example.com" # Original preserved
+
+ def test_get_user_returns_none_for_unknown(self, store: Any) -> None:
+ user = store.get_user("nonexistent-user")
+ assert user is None
+
+ def test_link_oauth_identity_by_email(self, store: Any) -> None:
+ """Pre-provisioned user (CLI) gets linked on first OAuth login by email match."""
+ # Pre-provision via CLI (no OAuth identity yet)
+ store.create_user_if_not_exists("pre-user", "pre@example.com", "Pre User")
+ # First OAuth login — link by email
+ linked_id = store.link_oauth_identity("pre-sub-789", TEST_ISSUER, "pre@example.com")
+ assert linked_id == "pre-user"
+ # Verify OAuth columns populated
+ user = store.get_user_by_oauth(TEST_ISSUER, "pre-sub-789")
+ assert user is not None
+ assert user["id"] == "pre-user"
+
+ def test_link_oauth_identity_no_match(self, store: Any) -> None:
+ """Linking returns None when no user has the given email."""
+ linked_id = store.link_oauth_identity("orphan-sub", TEST_ISSUER, "nobody@example.com")
+ assert linked_id is None
+
+ def test_link_oauth_identity_already_linked(self, store: Any) -> None:
+ """Linking is a no-op if user already has an OAuth identity."""
+ store.create_user_if_not_exists(
+ "linked-user", "linked@example.com", "Linked", "existing-sub", TEST_ISSUER
+ )
+ # Try to link again with different sub — should not overwrite
+ linked_id = store.link_oauth_identity("new-sub", TEST_ISSUER, "linked@example.com")
+ assert linked_id is None # Already linked, no update