Skip to content

Commit 603e0c6

Browse files
committed
Auth
1 parent 739b786 commit 603e0c6

File tree

17 files changed

+817
-2
lines changed

17 files changed

+817
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dependencies = [
3333
"starlette>=0.47.1",
3434
"aiohttp>=3.12.14",
3535
"authlib>=1.6.0",
36+
"jsonpath-ng>=1.6.1",
3637
]
3738

3839
[tool.pyright]

src/app/endpoints/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from models.config import Configuration
99
from configuration import configuration
10+
from authorization.middleware import authorize
11+
from authorization.models import Action
1012
from utils.endpoints import check_configuration_loaded
1113

1214
logger = logging.getLogger(__name__)
@@ -56,6 +58,7 @@
5658

5759

5860
@router.get("/config", responses=get_config_responses)
61+
@authorize(Action.GET_CONFIG)
5962
def config_endpoint_handler(_request: Request) -> Configuration:
6063
"""Handle requests to the /config endpoint."""
6164
# ensure that configuration is loaded

src/app/endpoints/conversations.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from configuration import configuration
1212
from models.responses import ConversationResponse, ConversationDeleteResponse
1313
from auth import get_auth_dependency
14+
from authorization.middleware import authorize
15+
from authorization.models import Action
1416
from utils.endpoints import check_configuration_loaded
1517
from utils.suid import check_suid
1618

@@ -110,6 +112,7 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
110112

111113

112114
@router.get("/conversations/{conversation_id}", responses=conversation_responses)
115+
@authorize(Action.GET_CONVERSATION)
113116
async def get_conversation_endpoint_handler(
114117
conversation_id: str,
115118
_auth: Any = Depends(auth_dependency),
@@ -179,6 +182,7 @@ async def get_conversation_endpoint_handler(
179182
@router.delete(
180183
"/conversations/{conversation_id}", responses=conversation_delete_responses
181184
)
185+
@authorize(Action.DELETE_CONVERSATION)
182186
async def delete_conversation_endpoint_handler(
183187
conversation_id: str,
184188
_auth: Any = Depends(auth_dependency),

src/app/endpoints/feedback.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
from auth import get_auth_dependency
1111
from auth.interface import AuthTuple
12+
from authorization.middleware import authorize
13+
from authorization.models import Action
1214
from configuration import configuration
1315
from models.responses import (
1416
FeedbackResponse,
@@ -64,6 +66,7 @@ async def assert_feedback_enabled(_request: Request) -> None:
6466

6567

6668
@router.post("", responses=feedback_response)
69+
@authorize(Action.FEEDBACK)
6770
def feedback_endpoint_handler(
6871
feedback_request: FeedbackRequest,
6972
auth: Annotated[AuthTuple, Depends(auth_dependency)],

src/app/endpoints/metrics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
CONTENT_TYPE_LATEST,
88
)
99

10+
from authorization.middleware import authorize
11+
from authorization.models import Action
1012
from metrics.utils import setup_model_metrics
1113

1214
router = APIRouter(tags=["metrics"])
1315

1416

1517
@router.get("/metrics", response_class=PlainTextResponse)
18+
@authorize(Action.GET_METRICS)
1619
async def metrics_endpoint_handler(_request: Request) -> PlainTextResponse:
1720
"""Handle request to the /metrics endpoint."""
1821
# Setup the model metrics if not already done. This is a one-time setup

src/app/endpoints/models.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,25 @@
33
import logging
44
from typing import Any
55

6+
from fastapi.params import Depends
67
from llama_stack_client import APIConnectionError
78
from fastapi import APIRouter, HTTPException, Request, status
89

910
from client import AsyncLlamaStackClientHolder
1011
from configuration import configuration
12+
from authorization.middleware import authorize
13+
from authorization.models import Action
1114
from models.responses import ModelsResponse
1215
from utils.endpoints import check_configuration_loaded
16+
from auth import get_auth_dependency
1317

1418
logger = logging.getLogger(__name__)
1519
router = APIRouter(tags=["models"])
1620

1721

22+
auth_dependency = get_auth_dependency()
23+
24+
1825
models_responses: dict[int | str, dict[str, Any]] = {
1926
200: {
2027
"models": [
@@ -43,8 +50,15 @@
4350

4451

4552
@router.get("/models", responses=models_responses)
46-
async def models_endpoint_handler(_request: Request) -> ModelsResponse:
53+
@authorize(Action.GET_MODELS)
54+
async def models_endpoint_handler(
55+
_request: Request, auth: Any = Depends(get_auth_dependency())
56+
) -> ModelsResponse:
4757
"""Handle requests to the /models endpoint."""
58+
59+
# Used only by the middleware
60+
_ = auth
61+
4862
check_configuration_loaded(configuration)
4963

5064
llama_stack_configuration = configuration.llama_stack_configuration

src/app/endpoints/query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
2727
from models.requests import QueryRequest, Attachment
2828
import constants
29+
from authorization.middleware import authorize
30+
from authorization.models import Action
2931
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
3032
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
3133
from utils.suid import get_suid
@@ -66,6 +68,7 @@ def is_transcripts_enabled() -> bool:
6668

6769

6870
@router.post("/query", responses=query_response)
71+
@authorize(Action.QUERY)
6972
async def query_endpoint_handler(
7073
query_request: QueryRequest,
7174
auth: Annotated[AuthTuple, Depends(auth_dependency)],

src/app/endpoints/streaming_query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
from auth import get_auth_dependency
2121
from auth.interface import AuthTuple
22+
from authorization.middleware import authorize
23+
from authorization.models import Action
2224
from client import AsyncLlamaStackClientHolder
2325
from configuration import configuration
2426
import metrics
@@ -380,6 +382,7 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]:
380382

381383

382384
@router.post("/streaming_query")
385+
@authorize(Action.STREAMING_QUERY)
383386
async def streaming_query_endpoint_handler(
384387
_request: Request,
385388
query_request: QueryRequest,

src/auth/jwk_token.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
from asyncio import Lock
55
from typing import Any, Callable
6+
import json
67

78
from fastapi import Request, HTTPException, status
89
from authlib.jose import JsonWebKey, KeySet, jwt, Key
@@ -188,4 +189,4 @@ async def __call__(self, request: Request) -> tuple[str, str, str]:
188189

189190
logger.info("Successfully authenticated user %s (ID: %s)", username, user_id)
190191

191-
return user_id, username, user_token
192+
return user_id, username, json.dumps(claims)

src/authorization/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Authorization module for role-based access control."""

0 commit comments

Comments
 (0)