diff --git a/backend/db/models.py b/backend/db/models.py index 0156fd9..5832d8d 100644 --- a/backend/db/models.py +++ b/backend/db/models.py @@ -182,7 +182,4 @@ def validate_extractor_owner( extractor = ( session.query(Extractor).filter_by(uuid=extractor_id, owner_id=user_id).first() ) - if extractor is None: - return False - else: - return True + return extractor is not None diff --git a/backend/server/api/api_key.py b/backend/server/api/api_key.py new file mode 100644 index 0000000..037f549 --- /dev/null +++ b/backend/server/api/api_key.py @@ -0,0 +1,6 @@ +from fastapi.security import APIKeyHeader + +# For actual auth, you'd need to check the key against a database or some other +# data store. Here, we don't need actual auth, just a key that matches +# a UUID +UserToken = APIKeyHeader(name="x-key") diff --git a/backend/server/api/examples.py b/backend/server/api/examples.py index b0fbdc8..5b9f4a1 100644 --- a/backend/server/api/examples.py +++ b/backend/server/api/examples.py @@ -2,11 +2,12 @@ from typing import Any, List from uuid import UUID -from fastapi import APIRouter, Cookie, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session from typing_extensions import Annotated, TypedDict from db.models import Example, get_session, validate_extractor_owner +from server.api.api_key import UserToken router = APIRouter( prefix="/examples", @@ -36,7 +37,7 @@ def create( create_request: CreateExample, *, session: Session = Depends(get_session), - user_id: UUID = Cookie(...), + user_id: UUID = Depends(UserToken), ) -> CreateExampleResponse: """Endpoint to create an example.""" if not validate_extractor_owner(session, create_request["extractor_id"], user_id): @@ -59,7 +60,7 @@ def list( limit: int = 10, offset: int = 0, session=Depends(get_session), - user_id: UUID = Cookie(...), + user_id: UUID = Depends(UserToken), ) -> List[Any]: """Endpoint to get all examples.""" if not validate_extractor_owner(session, extractor_id, user_id): @@ -76,7 +77,10 @@ def list( @router.delete("/{uuid}") def delete( - uuid: UUID, *, session: Session = Depends(get_session), user_id: UUID = Cookie(...) + uuid: UUID, + *, + session: Session = Depends(get_session), + user_id: UUID = Depends(UserToken), ) -> None: """Endpoint to delete an example.""" extractor_id = session.query(Example).filter_by(uuid=str(uuid)).first().extractor_id diff --git a/backend/server/api/extract.py b/backend/server/api/extract.py index ced707c..d46c245 100644 --- a/backend/server/api/extract.py +++ b/backend/server/api/extract.py @@ -1,12 +1,13 @@ from typing import Literal, Optional from uuid import UUID -from fastapi import APIRouter, Cookie, Depends, File, Form, HTTPException, UploadFile +from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile from sqlalchemy.orm import Session from typing_extensions import Annotated from db.models import Extractor, SharedExtractors, get_session from extraction.parsing import parse_binary_input +from server.api.api_key import UserToken from server.extraction_runnable import ExtractResponse, extract_entire_document from server.models import DEFAULT_MODEL from server.retrieval import extract_from_content @@ -27,7 +28,7 @@ async def extract_using_existing_extractor( file: Optional[UploadFile] = File(None), model_name: Optional[str] = Form(DEFAULT_MODEL), session: Session = Depends(get_session), - user_id: UUID = Cookie(...), + user_id: UUID = Depends(UserToken), ) -> ExtractResponse: """Endpoint that is used with an existing extractor. diff --git a/backend/server/api/extractors.py b/backend/server/api/extractors.py index 64f6f52..9ffc3e9 100644 --- a/backend/server/api/extractors.py +++ b/backend/server/api/extractors.py @@ -2,12 +2,13 @@ from typing import Any, Dict, List from uuid import UUID, uuid4 -from fastapi import APIRouter, Cookie, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, Field, validator from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from db.models import Extractor, SharedExtractors, get_session, validate_extractor_owner +from server.api.api_key import UserToken from server.validators import validate_json_schema router = APIRouter( @@ -60,7 +61,7 @@ def share( uuid: UUID, *, session: Session = Depends(get_session), - user_id: UUID = Cookie(...), + user_id: UUID = Depends(UserToken), ) -> ShareExtractorResponse: """Endpoint to share an extractor. @@ -110,7 +111,7 @@ def create( create_request: CreateExtractor, *, session: Session = Depends(get_session), - user_id: UUID = Cookie(...), + user_id: UUID = Depends(UserToken), ) -> CreateExtractorResponse: """Endpoint to create an extractor.""" @@ -128,7 +129,10 @@ def create( @router.get("/{uuid}") def get( - uuid: UUID, *, session: Session = Depends(get_session), user_id: UUID = Cookie(...) + uuid: UUID, + *, + session: Session = Depends(get_session), + user_id: UUID = Depends(UserToken), ) -> Dict[str, Any]: """Endpoint to get an extractor.""" extractor = ( @@ -151,7 +155,7 @@ def list( limit: int = 10, offset: int = 0, session=Depends(get_session), - user_id: UUID = Cookie(...), + user_id: UUID = Depends(UserToken), ) -> List[Any]: """Endpoint to get all extractors.""" return ( @@ -165,7 +169,10 @@ def list( @router.delete("/{uuid}") def delete( - uuid: UUID, *, session: Session = Depends(get_session), user_id: UUID = Cookie(...) + uuid: UUID, + *, + session: Session = Depends(get_session), + user_id: UUID = Depends(UserToken), ) -> None: """Endpoint to delete an extractor.""" session.query(Extractor).filter_by(uuid=str(uuid), owner_id=user_id).delete() diff --git a/backend/server/main.py b/backend/server/main.py index befdd41..7b2839f 100644 --- a/backend/server/main.py +++ b/backend/server/main.py @@ -68,10 +68,10 @@ def ready() -> str: # Serve the frontend -ui_dir = str(ROOT / "ui") +UI_DIR = str(ROOT / "ui") -if os.path.exists(ui_dir): - app.mount("/", StaticFiles(directory=ui_dir, html=True), name="ui") +if os.path.exists(UI_DIR): + app.mount("/", StaticFiles(directory=UI_DIR, html=True), name="ui") else: logger.warning("No UI directory found, serving API only.") diff --git a/backend/tests/unit_tests/api/test_api_defining_extractors.py b/backend/tests/unit_tests/api/test_api_defining_extractors.py index c9fa661..fe05194 100644 --- a/backend/tests/unit_tests/api/test_api_defining_extractors.py +++ b/backend/tests/unit_tests/api/test_api_defining_extractors.py @@ -9,8 +9,8 @@ async def test_extractors_api() -> None: # First verify that the database is empty async with get_async_client() as client: user_id = str(uuid.uuid4()) - cookies = {"user_id": user_id} - response = await client.get("/extractors", cookies=cookies) + headers = {"x-key": user_id} + response = await client.get("/extractors", headers=headers) assert response.status_code == 200 assert response.json() == [] @@ -21,38 +21,36 @@ async def test_extractors_api() -> None: "instruction": "Test Instruction", } response = await client.post( - "/extractors", json=create_request, cookies=cookies + "/extractors", json=create_request, headers=headers ) assert response.status_code == 200 # Verify that the extractor was created - response = await client.get("/extractors", cookies=cookies) + response = await client.get("/extractors", headers=headers) assert response.status_code == 200 get_response = response.json() assert len(get_response) == 1 - # Check cookies - bad_cookies = {"user_id": str(uuid.uuid4())} - bad_response = await client.get("/extractors", cookies=bad_cookies) + # Check headers + bad_headers = {"x-key": str(uuid.uuid4())} + bad_response = await client.get("/extractors", headers=bad_headers) assert bad_response.status_code == 200 assert len(bad_response.json()) == 0 # Check we need cookie to delete uuid_str = get_response[0]["uuid"] _ = uuid.UUID(uuid_str) # assert valid uuid - bad_response = await client.delete( - f"/extractors/{uuid_str}", cookies=bad_cookies - ) + await client.delete(f"/extractors/{uuid_str}", headers=bad_headers) # Check extractor was not deleted - response = await client.get("/extractors", cookies=cookies) + response = await client.get("/extractors", headers=headers) assert len(response.json()) == 1 # Verify that we can delete an extractor _ = uuid.UUID(uuid_str) # assert valid uuid - response = await client.delete(f"/extractors/{uuid_str}", cookies=cookies) + response = await client.delete(f"/extractors/{uuid_str}", headers=headers) assert response.status_code == 200 - get_response = await client.get("/extractors", cookies=cookies) + get_response = await client.get("/extractors", headers=headers) assert get_response.status_code == 200 assert get_response.json() == [] @@ -63,12 +61,12 @@ async def test_extractors_api() -> None: "instruction": "Test Instruction", } response = await client.post( - "/extractors", json=create_request, cookies=cookies + "/extractors", json=create_request, headers=headers ) assert response.status_code == 200 # Verify that the extractor was created - response = await client.get("/extractors", cookies=cookies) + response = await client.get("/extractors", headers=headers) assert response.status_code == 200 assert len(response.json()) == 1 @@ -76,10 +74,10 @@ async def test_extractors_api() -> None: get_response = response.json() uuid_str = get_response[0]["uuid"] _ = uuid.UUID(uuid_str) # assert valid uuid - response = await client.delete(f"/extractors/{uuid_str}", cookies=cookies) + response = await client.delete(f"/extractors/{uuid_str}", headers=headers) assert response.status_code == 200 - get_response = await client.get("/extractors", cookies=cookies) + get_response = await client.get("/extractors", headers=headers) assert get_response.status_code == 200 assert get_response.json() == [] @@ -92,11 +90,11 @@ async def test_extractors_api() -> None: "instruction": "Test Instruction", } response = await client.post( - "/extractors", json=create_request, cookies=cookies + "/extractors", json=create_request, headers=headers ) extractor_uuid = response.json()["uuid"] assert response.status_code == 200 - response = await client.get(f"/extractors/{extractor_uuid}", cookies=cookies) + response = await client.get(f"/extractors/{extractor_uuid}", headers=headers) response_data = response.json() assert extractor_uuid == response_data["uuid"] assert "my extractor" == response_data["name"] @@ -107,8 +105,8 @@ async def test_sharing_extractor() -> None: """Test sharing an extractor.""" async with get_async_client() as client: user_id = str(uuid.uuid4()) - cookies = {"user_id": user_id} - response = await client.get("/extractors", cookies=cookies) + headers = {"x-key": user_id} + response = await client.get("/extractors", headers=headers) assert response.status_code == 200 assert response.json() == [] # Verify that we can create an extractor @@ -119,28 +117,28 @@ async def test_sharing_extractor() -> None: "instruction": "Test Instruction", } response = await client.post( - "/extractors", json=create_request, cookies=cookies + "/extractors", json=create_request, headers=headers ) assert response.status_code == 200 uuid_str = response.json()["uuid"] # Generate a share uuid - response = await client.post(f"/extractors/{uuid_str}/share", cookies=cookies) + response = await client.post(f"/extractors/{uuid_str}/share", headers=headers) assert response.status_code == 200 assert "share_uuid" in response.json() share_uuid = response.json()["share_uuid"] # Test idempotency - response = await client.post(f"/extractors/{uuid_str}/share", cookies=cookies) + response = await client.post(f"/extractors/{uuid_str}/share", headers=headers) assert response.status_code == 200 assert "share_uuid" in response.json() assert response.json()["share_uuid"] == share_uuid - # Check cookies - bad_cookies = {"user_id": str(uuid.uuid4())} + # Check headers + bad_headers = {"x-key": str(uuid.uuid4())} response = await client.post( - f"/extractors/{uuid_str}/share", cookies=bad_cookies + f"/extractors/{uuid_str}/share", headers=bad_headers ) assert response.status_code == 404 diff --git a/backend/tests/unit_tests/api/test_api_examples.py b/backend/tests/unit_tests/api/test_api_examples.py index 23affc7..d3ccab1 100644 --- a/backend/tests/unit_tests/api/test_api_examples.py +++ b/backend/tests/unit_tests/api/test_api_examples.py @@ -16,7 +16,7 @@ async def test_examples_api() -> None: async with get_async_client() as client: # First create an extractor user_id = str(uuid.uuid4()) - cookies = {"user_id": user_id} + headers = {"x-key": user_id} create_request = { "description": "Test Description", "name": "Test Name", @@ -24,7 +24,7 @@ async def test_examples_api() -> None: "instruction": "Test Instruction", } response = await client.post( - "/extractors", json=create_request, cookies=cookies + "/extractors", json=create_request, headers=headers ) assert response.status_code == 200 # Get the extractor id @@ -32,7 +32,7 @@ async def test_examples_api() -> None: # Let's verify that there are no examples response = await client.get( - "/examples?extractor_id=" + extractor_id, cookies=cookies + "/examples?extractor_id=" + extractor_id, headers=headers ) assert response.status_code == 200 assert response.json() == [] @@ -48,20 +48,20 @@ async def test_examples_api() -> None: } ], } - response = await client.post("/examples", json=create_request, cookies=cookies) + response = await client.post("/examples", json=create_request, headers=headers) assert response.status_code == 200 example_id = response.json()["uuid"] - # Check cookies - bad_cookies = {"user_id": str(uuid.uuid4())} + # Check headers + bad_headers = {"x-key": str(uuid.uuid4())} response = await client.post( - "/examples", json=create_request, cookies=bad_cookies + "/examples", json=create_request, headers=bad_headers ) assert response.status_code == 404 # Verify that the example was created response = await client.get( - "/examples?extractor_id=" + extractor_id, cookies=cookies + "/examples?extractor_id=" + extractor_id, headers=headers ) assert response.status_code == 200 assert len(response.json()) == 1 @@ -82,23 +82,23 @@ async def test_examples_api() -> None: "uuid": example_id, } - # Check cookies + # Check headers response = await client.get( - "/examples?extractor_id=" + extractor_id, cookies=bad_cookies + "/examples?extractor_id=" + extractor_id, headers=bad_headers ) assert response.status_code == 404 # Check we need cookie to delete - response = await client.delete(f"/examples/{example_id}", cookies=bad_cookies) + response = await client.delete(f"/examples/{example_id}", headers=bad_headers) assert response.status_code == 404 # Verify that we can delete an example - response = await client.delete(f"/examples/{example_id}", cookies=cookies) + response = await client.delete(f"/examples/{example_id}", headers=headers) assert response.status_code == 200 # Verify that the example was deleted response = await client.get( - "/examples?extractor_id=" + extractor_id, cookies=cookies + "/examples?extractor_id=" + extractor_id, headers=headers ) assert response.status_code == 200 assert response.json() == [] diff --git a/backend/tests/unit_tests/api/test_api_extract.py b/backend/tests/unit_tests/api/test_api_extract.py index f999d64..2860c6d 100644 --- a/backend/tests/unit_tests/api/test_api_extract.py +++ b/backend/tests/unit_tests/api/test_api_extract.py @@ -42,7 +42,7 @@ async def test_extract_from_file() -> None: """Test extract from file API.""" async with get_async_client() as client: user_id = str(uuid4()) - cookies = {"user_id": user_id} + headers = {"x-key": user_id} # Test with invalid extractor extractor_id = UUID(int=1027) # 1027 is a good number. response = await client.post( @@ -51,7 +51,7 @@ async def test_extract_from_file() -> None: "extractor_id": str(extractor_id), "text": "Test Content", }, - cookies=cookies, + headers=headers, ) assert response.status_code == 404, response.text @@ -63,7 +63,9 @@ async def test_extract_from_file() -> None: "instruction": "Test Instruction", } response = await client.post( - "/extractors", json=create_request, cookies=cookies + "/extractors", + json=create_request, + headers=headers, ) assert response.status_code == 200, response.text # Get the extractor id @@ -78,7 +80,7 @@ async def test_extract_from_file() -> None: "text": "Test Content", "mode": "entire_document", }, - cookies=cookies, + headers=headers, ) assert response.status_code == 200 assert response.json() == {"data": ["Test Conte"]} @@ -92,7 +94,7 @@ async def test_extract_from_file() -> None: "mode": "entire_document", "model_name": "gpt-3.5-turbo", }, - cookies=cookies, + headers=headers, ) assert response.status_code == 200 assert response.json() == {"data": ["Test Conte"]} @@ -105,7 +107,7 @@ async def test_extract_from_file() -> None: "text": "Test Content", "mode": "retrieval", }, - cookies=cookies, + headers=headers, ) assert response.status_code == 200 assert response.json() == {"data": ["Test Conte"]} @@ -123,7 +125,7 @@ async def test_extract_from_file() -> None: "mode": "entire_document", }, files={"file": f}, - cookies=cookies, + headers=headers, ) assert response.status_code == 200, response.text @@ -132,7 +134,7 @@ async def test_extract_from_file() -> None: async def test_extract_from_large_file() -> None: user_id = str(uuid4()) - cookies = {"user_id": user_id} + headers = {"x-key": user_id} async with get_async_client() as client: # First create an extractor create_request = { @@ -142,7 +144,7 @@ async def test_extract_from_large_file() -> None: "instruction": "Test Instruction", } response = await client.post( - "/extractors", json=create_request, cookies=cookies + "/extractors", json=create_request, headers=headers ) assert response.status_code == 200, response.text # Get the extractor id @@ -161,7 +163,7 @@ async def test_extract_from_large_file() -> None: "mode": "entire_document", }, files={"file": f}, - cookies=cookies, + headers=headers, ) assert response.status_code == 413 @@ -181,6 +183,6 @@ async def test_extract_from_large_file() -> None: "mode": "entire_document", }, files={"file": f.name}, - cookies=cookies, + headers=headers, ) assert response.status_code == 413