diff --git a/CHANGELOG.md b/CHANGELOG.md index 443eedd..779e817 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- **OAuth 2.1 resource server**: provider-agnostic JWKS-based token validation for external OAuth providers (WorkOS, Auth0, Cloudflare Access, Keycloak, etc.) +- **Dual auth**: self-signed JWTs (via CLI) and OAuth provider tokens both accepted — OAuth for interactive clients, self-signed for edge providers/scripts +- **User auto-provisioning**: auto-create user record on first valid OAuth login (`AWARENESS_OAUTH_AUTO_PROVISION`, default: false for tighter control during early access) +- **Well-known metadata**: `/.well-known/oauth-protected-resource` (RFC 9728) for OAuth discovery by MCP clients +- **OAuth env vars**: `AWARENESS_OAUTH_ISSUER`, `AWARENESS_OAUTH_AUDIENCE`, `AWARENESS_OAUTH_JWKS_URI`, `AWARENESS_OAUTH_USER_CLAIM`, `AWARENESS_OAUTH_AUTO_PROVISION` - **JWT auth middleware**: opt-in via `AWARENESS_AUTH_REQUIRED=true`, validates Bearer tokens, extracts owner_id from `sub` claim - **Row-level security**: Postgres RLS policies on all data tables as defense-in-depth alongside application-level owner_id filtering - **CLI: `mcp-awareness-user`**: add/list/set-password/export/delete users with email normalization, E.164 phone validation, argon2id password hashing diff --git a/alembic/versions/i4d5e6f7g8h9_add_oauth_columns.py b/alembic/versions/i4d5e6f7g8h9_add_oauth_columns.py new file mode 100644 index 0000000..ad5a757 --- /dev/null +++ b/alembic/versions/i4d5e6f7g8h9_add_oauth_columns.py @@ -0,0 +1,58 @@ +# mcp-awareness — ambient system awareness for AI agents +# Copyright (C) 2026 Chris Means +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""add OAuth identity columns to users table + +Revision ID: i4d5e6f7g8h9 +Revises: h3c4d5e6f7g8 +Create Date: 2026-03-29 20:00:00.000000 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "i4d5e6f7g8h9" +down_revision: str | Sequence[str] | None = "h3c4d5e6f7g8" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # OAuth identity: sub claim + issuer for provider-agnostic lookup + op.execute("ALTER TABLE users ADD COLUMN IF NOT EXISTS oauth_subject TEXT") + op.execute("ALTER TABLE users ADD COLUMN IF NOT EXISTS oauth_issuer TEXT") + # Unique constraint: one OAuth identity per provider per user + op.execute( + "CREATE UNIQUE INDEX IF NOT EXISTS ix_users_oauth_identity " + "ON users (oauth_issuer, oauth_subject) WHERE oauth_issuer IS NOT NULL" + ) + # Fast lookup index for every authenticated request + op.execute( + "CREATE INDEX IF NOT EXISTS ix_users_oauth_subject " + "ON users (oauth_subject) WHERE oauth_subject IS NOT NULL" + ) + + +def downgrade() -> None: + op.execute("DROP INDEX IF EXISTS ix_users_oauth_subject") + op.execute("DROP INDEX IF EXISTS ix_users_oauth_identity") + op.execute("ALTER TABLE users DROP COLUMN IF EXISTS oauth_issuer") + op.execute("ALTER TABLE users DROP COLUMN IF EXISTS oauth_subject") diff --git a/pyproject.toml b/pyproject.toml index c45025d..9bbfe73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "psycopg_pool>=3.2", "alembic>=1.13", "sqlalchemy>=2.0", - "PyJWT>=2.8", + "PyJWT[crypto]>=2.8", "argon2-cffi>=23.1", "phonenumbers>=8.13", "zxcvbn>=4.4", diff --git a/src/mcp_awareness/middleware.py b/src/mcp_awareness/middleware.py index 7d32d09..40e6805 100644 --- a/src/mcp_awareness/middleware.py +++ b/src/mcp_awareness/middleware.py @@ -93,13 +93,61 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) +class WellKnownMiddleware: + """Serve /.well-known/oauth-protected-resource (RFC 9728).""" + + def __init__( + self, + app: ASGIApp, + oauth_issuer: str, + host: str, + port: int, + mount_path: str = "", + ) -> None: + self.app = app + self.oauth_issuer = oauth_issuer.rstrip("/") + self.host = host + self.port = port + self.mount_path = mount_path + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http": + path = scope.get("path", "") + if path == "/.well-known/oauth-protected-resource": + metadata = { + "resource": f"https://{self.host}:{self.port}{self.mount_path}/mcp", + "authorization_servers": [self.oauth_issuer], + "token_methods": ["Bearer"], + } + resp = JSONResponse(metadata) + await resp(scope, receive, send) + return + await self.app(scope, receive, send) + + class AuthMiddleware: - """Validate JWT Bearer token and set owner context.""" + """Validate JWT Bearer token and set owner context. - def __init__(self, app: ASGIApp, jwt_secret: str, algorithm: str = "HS256") -> None: + Supports dual auth: self-signed JWTs (via shared secret) and OAuth provider + tokens (via JWKS). Self-signed is tried first; if it fails and an OAuth + validator is configured, the token is validated against the provider's keys. + """ + + def __init__( + self, + app: ASGIApp, + jwt_secret: str, + algorithm: str = "HS256", + oauth_validator: object | None = None, + auto_provision: bool = True, + resource_metadata_url: str = "", + ) -> None: self.app = app self.jwt_secret = jwt_secret self.algorithm = algorithm + self.oauth_validator = oauth_validator + self.auto_provision = auto_provision + self.resource_metadata_url = resource_metadata_url async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": @@ -107,8 +155,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return path = scope.get("path", "") - # Skip auth for health, favicon, and non-MCP paths - if path in ("/health", "/favicon.ico"): + # Skip auth for health, favicon, well-known, and non-MCP paths + if path in ("/health", "/favicon.ico") or path.startswith("/.well-known/"): await self.app(scope, receive, send) return @@ -116,32 +164,23 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: headers = dict(scope.get("headers", [])) auth_header = headers.get(b"authorization", b"").decode() if not auth_header.startswith("Bearer "): - resp = JSONResponse( - {"error": "Missing or invalid Authorization header"}, status_code=401 - ) + resp = self._unauthorized("Missing or invalid Authorization header") await resp(scope, receive, send) return token = auth_header[7:] # Strip "Bearer " - try: - import jwt - payload = jwt.decode(token, self.jwt_secret, algorithms=[self.algorithm]) - owner_id: str | None = payload.get("sub") - if not owner_id: - resp = JSONResponse({"error": "JWT missing 'sub' claim"}, status_code=401) - await resp(scope, receive, send) - return - except Exception as exc: - # Handle both ExpiredSignatureError and InvalidTokenError - import jwt as jwt_mod + # Try self-signed JWT first + owner_id, error = self._try_self_signed(token) - if isinstance(exc, jwt_mod.ExpiredSignatureError): - resp = JSONResponse({"error": "Token expired"}, status_code=401) - elif isinstance(exc, jwt_mod.InvalidTokenError): - resp = JSONResponse({"error": "Invalid token"}, status_code=401) - else: - raise + # Fall back to OAuth provider validation + if owner_id is None and self.oauth_validator is not None: + owner_id = await self._try_oauth(token) + if owner_id is not None: + error = None + + if owner_id is None: + resp = self._unauthorized(error or "Invalid token") await resp(scope, receive, send) return @@ -153,3 +192,103 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) finally: _owner_ctx.reset(token_reset) + + def _try_self_signed(self, token: str) -> tuple[str | None, str | None]: + """Validate a self-signed JWT (from mcp-awareness-token CLI). + + Returns (owner_id, error_message). On success error is None. + On failure owner_id is None and error carries the reason. + """ + if not self.jwt_secret: + return None, None + try: + import jwt + + payload = jwt.decode(token, self.jwt_secret, algorithms=[self.algorithm]) + owner_id: str | None = payload.get("sub") + if owner_id: + return owner_id, None + return None, "JWT missing 'sub' claim" + except Exception as exc: + import jwt as jwt_mod + + if isinstance(exc, jwt_mod.ExpiredSignatureError): + return None, "Token expired" + if isinstance(exc, jwt_mod.InvalidTokenError): + return None, "Invalid token" + return None, None + + async def _try_oauth(self, token: str) -> str | None: + """Validate an OAuth token against the external provider's JWKS.""" + from .oauth import OAuthTokenValidator + + validator: OAuthTokenValidator = self.oauth_validator # type: ignore[assignment] + try: + claims = validator.validate(token) + except Exception: + return None + + owner_id = claims["owner_id"] + oauth_subject = claims.get("oauth_subject") + oauth_issuer = claims.get("oauth_issuer") + email = claims.get("email") + + # Resolve user identity: OAuth lookup → email link → auto-provision + resolved_id = self._resolve_user( + owner_id, email, claims.get("name"), oauth_subject, oauth_issuer + ) + + return resolved_id or owner_id + + def _resolve_user( + self, + owner_id: str, + email: str | None, + display_name: str | None, + oauth_subject: str | None, + oauth_issuer: str | None, + ) -> str | None: + """Resolve OAuth token to a local user, linking or creating as needed. + + Resolution order: + 1. Look up by OAuth identity (issuer + subject) — already linked user + 2. If email present, try to link to a pre-provisioned user by email + 3. If auto_provision enabled, create a new user + 4. Otherwise return None (use owner_id from token as-is) + """ + try: + from .server import store + + # 1. Already linked? + if oauth_issuer and oauth_subject: + existing = store.get_user_by_oauth(oauth_issuer, oauth_subject) + if existing: + return str(existing["id"]) + + # 2. Pre-provisioned user with matching email? Link on first login. + if email and oauth_subject and oauth_issuer: + linked_id = store.link_oauth_identity(oauth_subject, oauth_issuer, email) + if linked_id: + return str(linked_id) + + # 3. Auto-provision new user + if self.auto_provision: + store.create_user_if_not_exists( + owner_id, email, display_name, oauth_subject, oauth_issuer + ) + return owner_id + + except Exception: + # Don't fail the request if user resolution fails + pass + + return None + + def _unauthorized(self, message: str) -> JSONResponse: + """Build a 401 response with proper WWW-Authenticate header.""" + headers: dict[str, str] = {} + if self.resource_metadata_url: + headers["WWW-Authenticate"] = f'Bearer resource_metadata="{self.resource_metadata_url}"' + else: + headers["WWW-Authenticate"] = "Bearer" + return JSONResponse({"error": message}, status_code=401, headers=headers) diff --git a/src/mcp_awareness/oauth.py b/src/mcp_awareness/oauth.py new file mode 100644 index 0000000..7df23a8 --- /dev/null +++ b/src/mcp_awareness/oauth.py @@ -0,0 +1,103 @@ +# mcp-awareness — ambient system awareness for AI agents +# Copyright (C) 2026 Chris Means +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""OAuth 2.1 resource server — JWKS-based token validation for external providers.""" + +from __future__ import annotations + +import time + +import jwt +from jwt import PyJWKClient + + +class OAuthTokenValidator: + """Validates OAuth access tokens (JWTs) against an external provider's JWKS. + + Provider-agnostic: works with any OIDC-compliant provider (WorkOS, Auth0, + Cloudflare Access, Keycloak, AWS Cognito, etc.). + """ + + def __init__( + self, + issuer: str, + audience: str = "", + jwks_uri: str = "", + user_claim: str = "sub", + jwks_cache_ttl: int = 3600, + ) -> None: + self.issuer = issuer.rstrip("/") + self.audience = audience + self.user_claim = user_claim + + # JWKS URI: explicit override or derive from issuer + if jwks_uri: + self._jwks_uri = jwks_uri + else: + self._jwks_uri = f"{self.issuer}/.well-known/jwks.json" + + self._jwk_client = PyJWKClient(self._jwks_uri, cache_jwk_set=True) + self._jwks_cache_ttl = jwks_cache_ttl + self._last_jwks_fetch: float = 0.0 + + def validate(self, token: str) -> dict[str, str]: + """Validate an OAuth JWT and return extracted identity claims. + + Returns: + dict with 'owner_id' and optional 'email', 'name' keys. + + Raises: + jwt.InvalidTokenError: if token is invalid, expired, or unverifiable. + """ + # Refresh JWKS cache if stale + now = time.monotonic() + if now - self._last_jwks_fetch > self._jwks_cache_ttl: + self._jwk_client = PyJWKClient(self._jwks_uri, cache_jwk_set=True) + self._last_jwks_fetch = now + + signing_key = self._jwk_client.get_signing_key_from_jwt(token) + + # Build kwargs for jwt.decode + kwargs: dict[str, object] = { + "key": signing_key.key, + "algorithms": ["RS256", "ES256"], + "issuer": self.issuer, + } + if self.audience: + kwargs["audience"] = self.audience + else: + kwargs["options"] = {"verify_aud": False} + + payload = jwt.decode(token, **kwargs) # type: ignore[arg-type] + + # Extract owner_id from configured claim + owner_id = payload.get(self.user_claim) + if not owner_id: + raise jwt.InvalidTokenError(f"Token missing '{self.user_claim}' claim") + + result: dict[str, str] = {"owner_id": str(owner_id)} + + # Extract identity fields for auto-provisioning and lookup + if "sub" in payload: + result["oauth_subject"] = str(payload["sub"]) + if "iss" in payload: + result["oauth_issuer"] = str(payload["iss"]) + if "email" in payload: + result["email"] = str(payload["email"]) + if "name" in payload: + result["name"] = str(payload["name"]) + + return result diff --git a/src/mcp_awareness/postgres_store.py b/src/mcp_awareness/postgres_store.py index 9afb0b2..d96cbf5 100644 --- a/src/mcp_awareness/postgres_store.py +++ b/src/mcp_awareness/postgres_store.py @@ -1171,6 +1171,52 @@ def get_referencing_entries(self, owner_id: str, entry_id: str) -> list[Entry]: (json.dumps([entry_id]),), ) + # ------------------------------------------------------------------ + # User operations (for OAuth auto-provisioning) + # ------------------------------------------------------------------ + + def get_user(self, user_id: str) -> dict[str, Any] | None: + """Look up a user by ID. Returns dict or None if not found.""" + with self._pool.connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(_load_sql("get_user"), (user_id,)) + return cur.fetchone() + + def create_user_if_not_exists( + self, + user_id: str, + email: str | None = None, + display_name: str | None = None, + oauth_subject: str | None = None, + oauth_issuer: str | None = None, + ) -> None: + """Auto-provision a user on first OAuth login. No-op if user exists.""" + with self._pool.connection() as conn, conn.transaction(), conn.cursor() as cur: + cur.execute( + _load_sql("create_user_auto"), + (user_id, email, display_name, oauth_subject, oauth_issuer), + ) + + def get_user_by_oauth(self, oauth_issuer: str, oauth_subject: str) -> dict[str, Any] | None: + """Look up a user by OAuth identity. Returns dict or None.""" + with self._pool.connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(_load_sql("get_user_by_oauth"), (oauth_issuer, oauth_subject)) + return cur.fetchone() + + def link_oauth_identity(self, oauth_subject: str, oauth_issuer: str, email: str) -> str | None: + """Link an OAuth identity to a pre-provisioned user matched by email. + + Returns the user ID if linked, None if no matching user found. + Only links if the user's oauth_subject is currently NULL (first-time link). + """ + with ( + self._pool.connection() as conn, + conn.transaction(), + conn.cursor(row_factory=dict_row) as cur, + ): + cur.execute(_load_sql("link_oauth_identity"), (oauth_subject, oauth_issuer, email)) + row = cur.fetchone() + return str(row["id"]) if row else None + def clear(self, owner_id: str) -> None: with self._pool.connection() as conn, conn.transaction(), conn.cursor() as cur: self._set_rls_context(cur, owner_id) diff --git a/src/mcp_awareness/server.py b/src/mcp_awareness/server.py index 44a47b3..5724ec4 100644 --- a/src/mcp_awareness/server.py +++ b/src/mcp_awareness/server.py @@ -80,6 +80,13 @@ JWT_SECRET = os.environ.get("AWARENESS_JWT_SECRET", "") JWT_ALGORITHM = os.environ.get("AWARENESS_JWT_ALGORITHM", "HS256") +# OAuth — external provider (WorkOS, Auth0, Cloudflare Access, Keycloak, etc.) +OAUTH_ISSUER = os.environ.get("AWARENESS_OAUTH_ISSUER", "") +OAUTH_AUDIENCE = os.environ.get("AWARENESS_OAUTH_AUDIENCE", "") +OAUTH_JWKS_URI = os.environ.get("AWARENESS_OAUTH_JWKS_URI", "") +OAUTH_USER_CLAIM = os.environ.get("AWARENESS_OAUTH_USER_CLAIM", "sub") +OAUTH_AUTO_PROVISION = os.environ.get("AWARENESS_OAUTH_AUTO_PROVISION", "false").lower() == "true" + # Embedding provider — optional, configured via env vars EMBEDDING_PROVIDER = os.environ.get("AWARENESS_EMBEDDING_PROVIDER", "") EMBEDDING_MODEL = os.environ.get("AWARENESS_EMBEDDING_MODEL", "nomic-embed-text") @@ -332,24 +339,67 @@ def main() -> None: print("Shutdown requested — exiting.", flush=True) +def _build_oauth_validator() -> object | None: + """Create an OAuthTokenValidator if an external issuer is configured.""" + if not OAUTH_ISSUER: + return None + from mcp_awareness.oauth import OAuthTokenValidator + + return OAuthTokenValidator( + issuer=OAUTH_ISSUER, + audience=OAUTH_AUDIENCE, + jwks_uri=OAUTH_JWKS_URI, + user_claim=OAUTH_USER_CLAIM, + ) + + +def _build_resource_metadata_url() -> str: + """Build the well-known resource metadata URL for WWW-Authenticate headers.""" + if MOUNT_PATH: + return f"{MOUNT_PATH}/.well-known/oauth-protected-resource" + return "/.well-known/oauth-protected-resource" + + +def _wrap_with_auth(app: Any) -> Any: + """Wrap an ASGI app with AuthMiddleware if auth is required.""" + if not AUTH_REQUIRED: + return app + + oauth_validator = _build_oauth_validator() + if not JWT_SECRET and not oauth_validator: + raise ValueError( + "AWARENESS_AUTH_REQUIRED=true requires either " + "AWARENESS_JWT_SECRET or AWARENESS_OAUTH_ISSUER (or both)" + ) + from mcp_awareness.middleware import AuthMiddleware + + return AuthMiddleware( + app, + jwt_secret=JWT_SECRET, + algorithm=JWT_ALGORITHM, + oauth_validator=oauth_validator, + auto_provision=OAUTH_AUTO_PROVISION, + resource_metadata_url=_build_resource_metadata_url(), + ) + + def _run() -> None: if TRANSPORT == "streamable-http" and MOUNT_PATH: import uvicorn from starlette.types import ASGIApp as _ASGIApp - from mcp_awareness.middleware import SecretPathMiddleware + from mcp_awareness.middleware import ( + SecretPathMiddleware, + WellKnownMiddleware, + ) inner_app = mcp.streamable_http_app() app: _ASGIApp = SecretPathMiddleware(inner_app, MOUNT_PATH, _health_response) - if AUTH_REQUIRED: - if not JWT_SECRET: - raise ValueError( - "AWARENESS_JWT_SECRET is required when AWARENESS_AUTH_REQUIRED=true" - ) - from mcp_awareness.middleware import AuthMiddleware + if OAUTH_ISSUER: + app = WellKnownMiddleware(app, OAUTH_ISSUER, HOST, PORT, MOUNT_PATH) - app = AuthMiddleware(app, JWT_SECRET, JWT_ALGORITHM) + app = _wrap_with_auth(app) config = uvicorn.Config(app, host=HOST, port=PORT) server = uvicorn.Server(config) @@ -363,16 +413,14 @@ def _run() -> None: from mcp_awareness.middleware import HealthMiddleware inner_app = mcp.streamable_http_app() - health_app = HealthMiddleware(inner_app, _health_response) + health_app: Any = HealthMiddleware(inner_app, _health_response) + + if OAUTH_ISSUER: + from mcp_awareness.middleware import WellKnownMiddleware - if AUTH_REQUIRED: - if not JWT_SECRET: - raise ValueError( - "AWARENESS_JWT_SECRET is required when AWARENESS_AUTH_REQUIRED=true" - ) - from mcp_awareness.middleware import AuthMiddleware + health_app = WellKnownMiddleware(health_app, OAUTH_ISSUER, HOST, PORT, MOUNT_PATH) - health_app = AuthMiddleware(health_app, JWT_SECRET, JWT_ALGORITHM) # type: ignore[assignment] + health_app = _wrap_with_auth(health_app) config = uvicorn.Config(health_app, host=HOST, port=PORT) server = uvicorn.Server(config) diff --git a/src/mcp_awareness/sql/create_tables.sql b/src/mcp_awareness/sql/create_tables.sql index 18c5267..39ea2d8 100644 --- a/src/mcp_awareness/sql/create_tables.sql +++ b/src/mcp_awareness/sql/create_tables.sql @@ -16,10 +16,19 @@ CREATE TABLE IF NOT EXISTS users ( display_name TEXT, timezone TEXT DEFAULT 'UTC', preferences JSONB NOT NULL DEFAULT '{{}}'::jsonb, + oauth_subject TEXT, + oauth_issuer TEXT, created TIMESTAMPTZ NOT NULL DEFAULT now(), updated TIMESTAMPTZ, deleted TIMESTAMPTZ ); +/* Ensure OAuth columns exist on tables created before migration i4d5e6f7g8h9 */ +ALTER TABLE users ADD COLUMN IF NOT EXISTS oauth_subject TEXT; +ALTER TABLE users ADD COLUMN IF NOT EXISTS oauth_issuer TEXT; +CREATE UNIQUE INDEX IF NOT EXISTS ix_users_oauth_identity + ON users (oauth_issuer, oauth_subject) WHERE oauth_issuer IS NOT NULL; +CREATE INDEX IF NOT EXISTS ix_users_oauth_subject + ON users (oauth_subject) WHERE oauth_subject IS NOT NULL; CREATE TABLE IF NOT EXISTS entries ( id TEXT PRIMARY KEY, diff --git a/src/mcp_awareness/sql/create_user_auto.sql b/src/mcp_awareness/sql/create_user_auto.sql new file mode 100644 index 0000000..03b1b13 --- /dev/null +++ b/src/mcp_awareness/sql/create_user_auto.sql @@ -0,0 +1,7 @@ +/* name: create_user_auto */ +/* mode: literal */ +/* Auto-provision a user on first OAuth login. No-op if user already exists. + Params: user_id, email, display_name, oauth_subject, oauth_issuer */ +INSERT INTO users (id, email, display_name, oauth_subject, oauth_issuer, created) +VALUES (%s, %s, %s, %s, %s, now()) +ON CONFLICT (id) DO NOTHING diff --git a/src/mcp_awareness/sql/get_user.sql b/src/mcp_awareness/sql/get_user.sql new file mode 100644 index 0000000..c886110 --- /dev/null +++ b/src/mcp_awareness/sql/get_user.sql @@ -0,0 +1,4 @@ +/* name: get_user */ +/* mode: literal */ +/* Get a user by ID. Params: user_id */ +SELECT id, email, display_name, timezone, created FROM users WHERE id = %s AND deleted IS NULL diff --git a/src/mcp_awareness/sql/get_user_by_oauth.sql b/src/mcp_awareness/sql/get_user_by_oauth.sql new file mode 100644 index 0000000..5aeba20 --- /dev/null +++ b/src/mcp_awareness/sql/get_user_by_oauth.sql @@ -0,0 +1,6 @@ +/* name: get_user_by_oauth */ +/* mode: literal */ +/* Look up a user by OAuth identity (issuer + subject). Params: oauth_issuer, oauth_subject */ +SELECT id, email, display_name, timezone, oauth_subject, oauth_issuer, created +FROM users +WHERE oauth_issuer = %s AND oauth_subject = %s AND deleted IS NULL diff --git a/src/mcp_awareness/sql/link_oauth_identity.sql b/src/mcp_awareness/sql/link_oauth_identity.sql new file mode 100644 index 0000000..a6505d7 --- /dev/null +++ b/src/mcp_awareness/sql/link_oauth_identity.sql @@ -0,0 +1,10 @@ +/* name: link_oauth_identity */ +/* mode: literal */ +/* Link an OAuth identity to an existing user found by email. + Sets oauth_subject and oauth_issuer on first OAuth login. + Only updates if oauth_subject is currently NULL (first-time link). + Params: oauth_subject, oauth_issuer, email */ +UPDATE users +SET oauth_subject = %s, oauth_issuer = %s, updated = now() +WHERE email = %s AND oauth_subject IS NULL AND deleted IS NULL +RETURNING id diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 3717c2a..4c80323 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -334,7 +334,7 @@ def test_http_auth_required_no_secret_raises(self, monkeypatch: pytest.MonkeyPat mock_app = MagicMock() monkeypatch.setattr(server_mod.mcp, "streamable_http_app", lambda: mock_app) - with pytest.raises(ValueError, match="AWARENESS_JWT_SECRET is required"): + with pytest.raises(ValueError, match="AWARENESS_AUTH_REQUIRED=true requires"): server_mod._run() @@ -528,11 +528,18 @@ async def noop_receive() -> dict[str, Any]: async def noop_send(msg: dict[str, Any]) -> None: pass - with ( - patch("jwt.decode", side_effect=RuntimeError("boom")), - pytest.raises(RuntimeError, match="boom"), - ): - await app(scope, noop_receive, noop_send) + sent: list[dict[str, Any]] = [] + + async def capture_send(msg: dict[str, Any]) -> None: + sent.append(msg) + + with patch("jwt.decode", side_effect=RuntimeError("boom")): + await app(scope, noop_receive, capture_send) + + # With dual auth, unexpected errors in self-signed validation + # fall through gracefully (returns 401, not a crash) + body = b"".join(m.get("body", b"") for m in sent if m["type"] == "http.response.body") + assert b"Invalid token" in body class TestServerAuthWiring: @@ -552,7 +559,7 @@ def test_http_with_mount_path_auth_no_secret_raises( mock_app = MagicMock() monkeypatch.setattr(server_mod.mcp, "streamable_http_app", lambda: mock_app) - with pytest.raises(ValueError, match="AWARENESS_JWT_SECRET is required"): + with pytest.raises(ValueError, match="AWARENESS_AUTH_REQUIRED=true requires"): server_mod._run() def test_http_no_mount_path_auth_required_wraps_with_auth_middleware( diff --git a/tests/test_oauth.py b/tests/test_oauth.py new file mode 100644 index 0000000..4b9735f --- /dev/null +++ b/tests/test_oauth.py @@ -0,0 +1,773 @@ +# mcp-awareness — ambient system awareness for AI agents +# Copyright (C) 2026 Chris Means +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Tests for OAuth 2.1 resource server (JWKS validation, auto-provisioning, metadata).""" + +from __future__ import annotations + +import json +from datetime import datetime, timedelta, timezone +from typing import Any +from unittest.mock import MagicMock + +import jwt +import pytest +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.serialization import ( + Encoding, + NoEncryption, + PrivateFormat, +) + +from mcp_awareness.middleware import AuthMiddleware, WellKnownMiddleware +from mcp_awareness.oauth import OAuthTokenValidator + +# --------------------------------------------------------------------------- +# RSA key pair for testing +# --------------------------------------------------------------------------- + +_TEST_PRIVATE_KEY = rsa.generate_private_key(public_exponent=65537, key_size=2048) +_TEST_PUBLIC_KEY = _TEST_PRIVATE_KEY.public_key() + +TEST_ISSUER = "https://auth.example.com" +TEST_AUDIENCE = "awareness-test" +TEST_OWNER = "test-owner" + + +def _make_token( + sub: str = TEST_OWNER, + issuer: str = TEST_ISSUER, + audience: str = TEST_AUDIENCE, + email: str | None = None, + name: str | None = None, + expired: bool = False, +) -> str: + """Create a signed RS256 JWT for testing.""" + now = datetime.now(timezone.utc) + payload: dict[str, Any] = { + "sub": sub, + "iss": issuer, + "aud": audience, + "iat": now, + "exp": now + timedelta(hours=-1 if expired else 1), + } + if email: + payload["email"] = email + if name: + payload["name"] = name + + private_pem = _TEST_PRIVATE_KEY.private_bytes(Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()) + return jwt.encode(payload, private_pem, algorithm="RS256") + + +# --------------------------------------------------------------------------- +# OAuthTokenValidator tests +# --------------------------------------------------------------------------- + + +class TestOAuthTokenValidator: + def _make_validator(self, audience: str = TEST_AUDIENCE) -> OAuthTokenValidator: + validator = OAuthTokenValidator( + issuer=TEST_ISSUER, + audience=audience, + user_claim="sub", + ) + return validator + + def _mock_jwk_client(self, validator: OAuthTokenValidator) -> None: + """Replace the PyJWKClient with a mock that returns our test key.""" + import time as _time + + mock_client = MagicMock() + mock_signing_key = MagicMock() + mock_signing_key.key = _TEST_PUBLIC_KEY + mock_client.get_signing_key_from_jwt.return_value = mock_signing_key + validator._jwk_client = mock_client + # Prevent validate() from replacing our mock with a real client + validator._last_jwks_fetch = _time.monotonic() + + def test_valid_token(self) -> None: + validator = self._make_validator() + self._mock_jwk_client(validator) + token = _make_token() + result = validator.validate(token) + assert result["owner_id"] == TEST_OWNER + + def test_token_with_email_and_name(self) -> None: + validator = self._make_validator() + self._mock_jwk_client(validator) + token = _make_token(email="alice@example.com", name="Alice") + result = validator.validate(token) + assert result["owner_id"] == TEST_OWNER + assert result["email"] == "alice@example.com" + assert result["name"] == "Alice" + + def test_expired_token_raises(self) -> None: + validator = self._make_validator() + self._mock_jwk_client(validator) + token = _make_token(expired=True) + with pytest.raises(jwt.ExpiredSignatureError): + validator.validate(token) + + def test_wrong_issuer_raises(self) -> None: + validator = self._make_validator() + self._mock_jwk_client(validator) + token = _make_token(issuer="https://wrong-issuer.com") + with pytest.raises(jwt.InvalidIssuerError): + validator.validate(token) + + def test_wrong_audience_raises(self) -> None: + validator = self._make_validator(audience="wrong-audience") + self._mock_jwk_client(validator) + token = _make_token() + with pytest.raises(jwt.InvalidAudienceError): + validator.validate(token) + + def test_missing_sub_claim_raises(self) -> None: + validator = self._make_validator() + self._mock_jwk_client(validator) + # Create a token without sub + now = datetime.now(timezone.utc) + payload = { + "iss": TEST_ISSUER, + "aud": TEST_AUDIENCE, + "iat": now, + "exp": now + timedelta(hours=1), + } + private_pem = _TEST_PRIVATE_KEY.private_bytes( + Encoding.PEM, PrivateFormat.PKCS8, NoEncryption() + ) + token = jwt.encode(payload, private_pem, algorithm="RS256") + with pytest.raises(jwt.InvalidTokenError, match="missing 'sub' claim"): + validator.validate(token) + + def test_custom_user_claim(self) -> None: + validator = OAuthTokenValidator( + issuer=TEST_ISSUER, + audience=TEST_AUDIENCE, + user_claim="email", + ) + self._mock_jwk_client(validator) + token = _make_token(email="alice@example.com") + result = validator.validate(token) + assert result["owner_id"] == "alice@example.com" + + def test_explicit_jwks_uri(self) -> None: + """Explicit jwks_uri overrides the default derived from issuer.""" + validator = OAuthTokenValidator( + issuer=TEST_ISSUER, + audience=TEST_AUDIENCE, + jwks_uri="https://custom.example.com/keys", + user_claim="sub", + ) + assert validator._jwks_uri == "https://custom.example.com/keys" + + def test_jwks_cache_refresh(self) -> None: + """JWKS client is refreshed when cache TTL expires.""" + validator = self._make_validator() + self._mock_jwk_client(validator) + # Force cache to be expired + validator._last_jwks_fetch = 0.0 + validator._jwks_cache_ttl = 0 # instant expiry + token = _make_token() + # This triggers cache refresh (creates new PyJWKClient) which fails + # because it tries to fetch from a fake URL. That's expected. + with pytest.raises(jwt.exceptions.PyJWKClientConnectionError): + validator.validate(token) + + def test_no_audience_skips_validation(self) -> None: + validator = OAuthTokenValidator( + issuer=TEST_ISSUER, + audience="", + user_claim="sub", + ) + self._mock_jwk_client(validator) + token = _make_token() + result = validator.validate(token) + assert result["owner_id"] == TEST_OWNER + + +# --------------------------------------------------------------------------- +# WellKnownMiddleware tests +# --------------------------------------------------------------------------- + + +class TestWellKnownMiddleware: + @pytest.mark.anyio + async def test_serves_protected_resource_metadata(self) -> None: + async def inner_app(scope: Any, receive: Any, send: Any) -> None: + pass + + app = WellKnownMiddleware(inner_app, TEST_ISSUER, "localhost", 8420) + scope = { + "type": "http", + "path": "/.well-known/oauth-protected-resource", + "method": "GET", + "headers": [], + } + + sent: list[dict[str, Any]] = [] + + async def noop_receive() -> dict[str, Any]: + return {"type": "http.request", "body": b""} + + async def capture_send(msg: dict[str, Any]) -> None: + sent.append(msg) + + await app(scope, noop_receive, capture_send) + + body = b"".join(m.get("body", b"") for m in sent if m["type"] == "http.response.body") + data = json.loads(body) + assert data["authorization_servers"] == [TEST_ISSUER] + assert data["token_methods"] == ["Bearer"] + assert "/mcp" in data["resource"] + + @pytest.mark.anyio + async def test_passes_through_other_paths(self) -> None: + called = False + + async def inner_app(scope: Any, receive: Any, send: Any) -> None: + nonlocal called + called = True + + app = WellKnownMiddleware(inner_app, TEST_ISSUER, "localhost", 8420) + scope = {"type": "http", "path": "/mcp", "method": "POST", "headers": []} + + async def noop_receive() -> dict[str, Any]: + return {"type": "http.request", "body": b""} + + async def noop_send(msg: dict[str, Any]) -> None: + pass + + await app(scope, noop_receive, noop_send) + assert called + + +# --------------------------------------------------------------------------- +# AuthMiddleware dual auth tests +# --------------------------------------------------------------------------- + + +class TestDualAuth: + @pytest.mark.anyio + async def test_self_signed_jwt_still_works(self) -> None: + """Existing self-signed JWT auth continues to work with OAuth configured.""" + secret = "test-secret-at-least-32-chars-long!" + token = jwt.encode({"sub": "alice"}, secret, algorithm="HS256") + + called_with_owner: list[str] = [] + + async def inner_app(scope: Any, receive: Any, send: Any) -> None: + from mcp_awareness.server import _owner_ctx + + called_with_owner.append(_owner_ctx.get()) + + mock_oauth = MagicMock() + app = AuthMiddleware(inner_app, jwt_secret=secret, oauth_validator=mock_oauth) + + scope = { + "type": "http", + "path": "/mcp", + "method": "POST", + "headers": [(b"authorization", f"Bearer {token}".encode())], + } + + async def noop_receive() -> dict[str, Any]: + return {"type": "http.request", "body": b""} + + async def noop_send(msg: dict[str, Any]) -> None: + pass + + await app(scope, noop_receive, noop_send) + assert called_with_owner == ["alice"] + # OAuth validator should NOT be called when self-signed JWT succeeds + mock_oauth.validate.assert_not_called() + + @pytest.mark.anyio + async def test_oauth_fallback_when_self_signed_fails(self) -> None: + """When self-signed JWT fails, falls back to OAuth validation.""" + oauth_token = "oauth-token-not-valid-as-self-signed" + + mock_oauth = MagicMock() + mock_oauth.validate.return_value = { + "owner_id": "oauth-user", + "email": "oauth@example.com", + } + + called_with_owner: list[str] = [] + + async def inner_app(scope: Any, receive: Any, send: Any) -> None: + from mcp_awareness.server import _owner_ctx + + called_with_owner.append(_owner_ctx.get()) + + app = AuthMiddleware( + inner_app, + jwt_secret="some-secret-at-least-32-chars!!!", + oauth_validator=mock_oauth, + auto_provision=False, + ) + + scope = { + "type": "http", + "path": "/mcp", + "method": "POST", + "headers": [(b"authorization", f"Bearer {oauth_token}".encode())], + } + + async def noop_receive() -> dict[str, Any]: + return {"type": "http.request", "body": b""} + + async def noop_send(msg: dict[str, Any]) -> None: + pass + + await app(scope, noop_receive, noop_send) + assert called_with_owner == ["oauth-user"] + mock_oauth.validate.assert_called_once_with(oauth_token) + + @pytest.mark.anyio + async def test_well_known_bypasses_auth(self) -> None: + """/.well-known/ paths should not require authentication.""" + sent: list[dict[str, Any]] = [] + + async def inner_app(scope: Any, receive: Any, send: Any) -> None: + pass + + app = AuthMiddleware(inner_app, jwt_secret="secret-at-least-32-chars!!!!!!") + + scope = { + "type": "http", + "path": "/.well-known/oauth-protected-resource", + "method": "GET", + "headers": [], # No auth header + } + + async def noop_receive() -> dict[str, Any]: + return {"type": "http.request", "body": b""} + + async def capture_send(msg: dict[str, Any]) -> None: + sent.append(msg) + + await app(scope, noop_receive, capture_send) + # Should NOT return 401 — well-known paths are public + status_codes = [m.get("status") for m in sent if m["type"] == "http.response.start"] + assert 401 not in status_codes + + @pytest.mark.anyio + async def test_resolve_user_called_on_oauth(self) -> None: + """When OAuth succeeds, _resolve_user is called with claims.""" + mock_oauth = MagicMock() + mock_oauth.validate.return_value = { + "owner_id": "new-user", + "email": "new@example.com", + "name": "New User", + "oauth_subject": "sub-123", + "oauth_issuer": TEST_ISSUER, + } + + called_with_owner: list[str] = [] + + async def inner_app(scope: Any, receive: Any, send: Any) -> None: + from mcp_awareness.server import _owner_ctx + + called_with_owner.append(_owner_ctx.get()) + + app = AuthMiddleware( + inner_app, + jwt_secret="", + oauth_validator=mock_oauth, + auto_provision=False, + ) + + scope = { + "type": "http", + "path": "/mcp", + "method": "POST", + "headers": [(b"authorization", b"Bearer oauth-token")], + } + + async def noop_receive() -> dict[str, Any]: + return {"type": "http.request", "body": b""} + + async def noop_send(msg: dict[str, Any]) -> None: + pass + + await app(scope, noop_receive, noop_send) + assert called_with_owner == ["new-user"] + + @pytest.mark.anyio + async def test_401_includes_www_authenticate(self) -> None: + """401 responses include WWW-Authenticate with resource_metadata URL.""" + sent: list[dict[str, Any]] = [] + + async def inner_app(scope: Any, receive: Any, send: Any) -> None: + pass + + mock_oauth = MagicMock() + mock_oauth.validate.side_effect = Exception("fail") + app = AuthMiddleware( + inner_app, + jwt_secret="", + oauth_validator=mock_oauth, + resource_metadata_url="/.well-known/oauth-protected-resource", + ) + + scope = { + "type": "http", + "path": "/mcp", + "method": "POST", + "headers": [(b"authorization", b"Bearer bad-token")], + } + + async def noop_receive() -> dict[str, Any]: + return {"type": "http.request", "body": b""} + + async def capture_send(msg: dict[str, Any]) -> None: + sent.append(msg) + + await app(scope, noop_receive, capture_send) + + response_start = next(m for m in sent if m["type"] == "http.response.start") + assert response_start["status"] == 401 + headers = dict(response_start.get("headers", [])) + www_auth = headers.get(b"www-authenticate", b"").decode() + assert "resource_metadata" in www_auth + + +# --------------------------------------------------------------------------- +# Auto-provisioning tests +# --------------------------------------------------------------------------- + + +class TestAutoProvisionIntegration: + """Integration test: middleware auto-provision with a real store.""" + + @pytest.mark.anyio + async def test_ensure_user_creates_record(self, store: Any, monkeypatch: Any) -> None: + """_ensure_user calls store.create_user_if_not_exists through the server module.""" + import mcp_awareness.server as server_mod + + # Point the middleware at our test store + monkeypatch.setattr(server_mod, "store", store) + + async def inner_app(scope: Any, receive: Any, send: Any) -> None: + pass + + mock_oauth = MagicMock() + mock_oauth.validate.return_value = { + "owner_id": "integration-user", + "email": "int@example.com", + "name": "Integration", + "oauth_subject": "int-sub", + "oauth_issuer": TEST_ISSUER, + } + + app = AuthMiddleware( + inner_app, + jwt_secret="", + oauth_validator=mock_oauth, + auto_provision=True, + ) + + scope = { + "type": "http", + "path": "/mcp", + "method": "POST", + "headers": [(b"authorization", b"Bearer oauth-token")], + } + + async def noop_receive() -> dict[str, Any]: + return {"type": "http.request", "body": b""} + + async def noop_send(msg: dict[str, Any]) -> None: + pass + + await app(scope, noop_receive, noop_send) + + # Verify user was created in the real store + user = store.get_user("integration-user") + assert user is not None + assert user["email"] == "int@example.com" + + @pytest.mark.anyio + async def test_resolve_finds_already_linked_user(self, store: Any, monkeypatch: Any) -> None: + """OAuth login resolves to existing user via oauth_subject lookup.""" + import mcp_awareness.server as server_mod + + monkeypatch.setattr(server_mod, "store", store) + + # Pre-create a linked user + store.create_user_if_not_exists( + "linked-alice", "alice@example.com", "Alice", "alice-sub", TEST_ISSUER + ) + + called_with_owner: list[str] = [] + + async def inner_app(scope: Any, receive: Any, send: Any) -> None: + from mcp_awareness.server import _owner_ctx + + called_with_owner.append(_owner_ctx.get()) + + mock_oauth = MagicMock() + mock_oauth.validate.return_value = { + "owner_id": "alice-sub", + "email": "alice@example.com", + "oauth_subject": "alice-sub", + "oauth_issuer": TEST_ISSUER, + } + + app = AuthMiddleware( + inner_app, jwt_secret="", oauth_validator=mock_oauth, auto_provision=False + ) + scope = { + "type": "http", + "path": "/mcp", + "method": "POST", + "headers": [(b"authorization", b"Bearer oauth-token")], + } + + async def noop_receive() -> dict[str, Any]: + return {"type": "http.request", "body": b""} + + async def noop_send(msg: dict[str, Any]) -> None: + pass + + await app(scope, noop_receive, noop_send) + # Should resolve to the existing user's ID, not the raw sub claim + assert called_with_owner == ["linked-alice"] + + @pytest.mark.anyio + async def test_resolve_links_pre_provisioned_user_by_email( + self, store: Any, monkeypatch: Any + ) -> None: + """First OAuth login links to a pre-provisioned user matched by email.""" + import mcp_awareness.server as server_mod + + monkeypatch.setattr(server_mod, "store", store) + + # Pre-provision via CLI (no OAuth identity) + store.create_user_if_not_exists("cli-bob", "bob@example.com", "Bob") + + called_with_owner: list[str] = [] + + async def inner_app(scope: Any, receive: Any, send: Any) -> None: + from mcp_awareness.server import _owner_ctx + + called_with_owner.append(_owner_ctx.get()) + + mock_oauth = MagicMock() + mock_oauth.validate.return_value = { + "owner_id": "bob-sub-xyz", + "email": "bob@example.com", + "oauth_subject": "bob-sub-xyz", + "oauth_issuer": TEST_ISSUER, + } + + app = AuthMiddleware( + inner_app, jwt_secret="", oauth_validator=mock_oauth, auto_provision=False + ) + scope = { + "type": "http", + "path": "/mcp", + "method": "POST", + "headers": [(b"authorization", b"Bearer oauth-token")], + } + + async def noop_receive() -> dict[str, Any]: + return {"type": "http.request", "body": b""} + + async def noop_send(msg: dict[str, Any]) -> None: + pass + + await app(scope, noop_receive, noop_send) + # Should resolve to the pre-provisioned user's ID via email linking + assert called_with_owner == ["cli-bob"] + # Verify OAuth identity was linked + user = store.get_user_by_oauth(TEST_ISSUER, "bob-sub-xyz") + assert user is not None + assert user["id"] == "cli-bob" + + +class TestAutoProvisionFailure: + """Verify _ensure_user swallows exceptions gracefully.""" + + @pytest.mark.anyio + async def test_ensure_user_exception_swallowed(self, monkeypatch: Any) -> None: + """Auto-provisioning failure must not block the request.""" + import mcp_awareness.server as server_mod + + broken_store = MagicMock() + broken_store.get_user_by_oauth.return_value = None + broken_store.link_oauth_identity.return_value = None + broken_store.create_user_if_not_exists.side_effect = RuntimeError("db down") + monkeypatch.setattr(server_mod, "store", broken_store) + + mock_oauth = MagicMock() + mock_oauth.validate.return_value = { + "owner_id": "failing-user", + "oauth_subject": "sub", + "oauth_issuer": TEST_ISSUER, + } + + called_with_owner: list[str] = [] + + async def inner_app(scope: Any, receive: Any, send: Any) -> None: + from mcp_awareness.server import _owner_ctx + + called_with_owner.append(_owner_ctx.get()) + + app = AuthMiddleware( + inner_app, jwt_secret="", oauth_validator=mock_oauth, auto_provision=True + ) + scope = { + "type": "http", + "path": "/mcp", + "method": "POST", + "headers": [(b"authorization", b"Bearer token")], + } + + async def noop_receive() -> dict[str, Any]: + return {"type": "http.request", "body": b""} + + async def noop_send(msg: dict[str, Any]) -> None: + pass + + await app(scope, noop_receive, noop_send) + # Request should succeed despite provisioning failure + assert called_with_owner == ["failing-user"] + + +class TestServerWiring: + def test_build_oauth_validator_returns_none_without_issuer(self) -> None: + """No OAuth validator when OAUTH_ISSUER is empty.""" + from mcp_awareness import server as server_mod + + original = server_mod.OAUTH_ISSUER + try: + server_mod.OAUTH_ISSUER = "" + assert server_mod._build_oauth_validator() is None + finally: + server_mod.OAUTH_ISSUER = original + + def test_wrap_with_auth_uses_oauth_when_issuer_set(self) -> None: + """_wrap_with_auth creates AuthMiddleware with OAuth validator when issuer is set.""" + from mcp_awareness import server as server_mod + + orig_issuer = server_mod.OAUTH_ISSUER + orig_required = server_mod.AUTH_REQUIRED + orig_secret = server_mod.JWT_SECRET + try: + server_mod.OAUTH_ISSUER = TEST_ISSUER + server_mod.AUTH_REQUIRED = True + server_mod.JWT_SECRET = "" # No self-signed — OAuth only + + async def dummy(scope: Any, receive: Any, send: Any) -> None: + pass + + wrapped = server_mod._wrap_with_auth(dummy) + assert isinstance(wrapped, AuthMiddleware) + assert wrapped.oauth_validator is not None + finally: + server_mod.OAUTH_ISSUER = orig_issuer + server_mod.AUTH_REQUIRED = orig_required + server_mod.JWT_SECRET = orig_secret + + def test_build_oauth_validator_returns_validator_with_issuer(self) -> None: + """OAuth validator created when OAUTH_ISSUER is set.""" + from mcp_awareness import server as server_mod + + original = server_mod.OAUTH_ISSUER + try: + server_mod.OAUTH_ISSUER = TEST_ISSUER + validator = server_mod._build_oauth_validator() + assert validator is not None + from mcp_awareness.oauth import OAuthTokenValidator + + assert isinstance(validator, OAuthTokenValidator) + finally: + server_mod.OAUTH_ISSUER = original + + +class TestAutoProvisioning: + def test_create_user_with_oauth_identity(self, store: Any) -> None: + """Auto-provisioning stores OAuth identity fields.""" + store.create_user_if_not_exists( + "oauth-carol", "carol@example.com", "Carol", "carol-sub-123", TEST_ISSUER + ) + user = store.get_user("oauth-carol") + assert user is not None + assert user["id"] == "oauth-carol" + + def test_get_user_by_oauth(self, store: Any) -> None: + """Look up user by OAuth issuer + subject pair.""" + store.create_user_if_not_exists( + "oauth-dan", "dan@example.com", "Dan", "dan-sub-456", TEST_ISSUER + ) + user = store.get_user_by_oauth(TEST_ISSUER, "dan-sub-456") + assert user is not None + assert user["id"] == "oauth-dan" + + def test_get_user_by_oauth_not_found(self, store: Any) -> None: + user = store.get_user_by_oauth(TEST_ISSUER, "nonexistent-sub") + assert user is None + + def test_create_user_if_not_exists(self, store: Any) -> None: + """Auto-provisioning creates a user that didn't exist.""" + store.create_user_if_not_exists("oauth-alice", "alice@example.com", "Alice") + user = store.get_user("oauth-alice") + assert user is not None + assert user["id"] == "oauth-alice" + assert user["email"] == "alice@example.com" + assert user["display_name"] == "Alice" + + def test_create_user_if_not_exists_no_op_when_exists(self, store: Any) -> None: + """Auto-provisioning is a no-op for existing users.""" + store.create_user_if_not_exists("oauth-bob", "bob@example.com", "Bob") + # Create again with different email — should NOT update + store.create_user_if_not_exists("oauth-bob", "new@example.com", "New Bob") + user = store.get_user("oauth-bob") + assert user is not None + assert user["email"] == "bob@example.com" # Original preserved + + def test_get_user_returns_none_for_unknown(self, store: Any) -> None: + user = store.get_user("nonexistent-user") + assert user is None + + def test_link_oauth_identity_by_email(self, store: Any) -> None: + """Pre-provisioned user (CLI) gets linked on first OAuth login by email match.""" + # Pre-provision via CLI (no OAuth identity yet) + store.create_user_if_not_exists("pre-user", "pre@example.com", "Pre User") + # First OAuth login — link by email + linked_id = store.link_oauth_identity("pre-sub-789", TEST_ISSUER, "pre@example.com") + assert linked_id == "pre-user" + # Verify OAuth columns populated + user = store.get_user_by_oauth(TEST_ISSUER, "pre-sub-789") + assert user is not None + assert user["id"] == "pre-user" + + def test_link_oauth_identity_no_match(self, store: Any) -> None: + """Linking returns None when no user has the given email.""" + linked_id = store.link_oauth_identity("orphan-sub", TEST_ISSUER, "nobody@example.com") + assert linked_id is None + + def test_link_oauth_identity_already_linked(self, store: Any) -> None: + """Linking is a no-op if user already has an OAuth identity.""" + store.create_user_if_not_exists( + "linked-user", "linked@example.com", "Linked", "existing-sub", TEST_ISSUER + ) + # Try to link again with different sub — should not overwrite + linked_id = store.link_oauth_identity("new-sub", TEST_ISSUER, "linked@example.com") + assert linked_id is None # Already linked, no update