Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev committed Mar 22, 2024
1 parent 59fbc8d commit f3e71cc
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 70 deletions.
5 changes: 1 addition & 4 deletions backend/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,4 @@ def validate_extractor_owner(
extractor = (
session.query(Extractor).filter_by(uuid=extractor_id, owner_id=user_id).first()
)
if extractor is None:
return False
else:
return True
return extractor is not None
6 changes: 6 additions & 0 deletions backend/server/api/api_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from fastapi.security import APIKeyHeader

# For actual auth, you'd need to check the key against a database or some other
# data store. Here, we don't need actual auth, just a key that matches
# a UUID
UserToken = APIKeyHeader(name="x-key")
12 changes: 8 additions & 4 deletions backend/server/api/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from typing import Any, List
from uuid import UUID

from fastapi import APIRouter, Cookie, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing_extensions import Annotated, TypedDict

from db.models import Example, get_session, validate_extractor_owner
from server.api.api_key import UserToken

router = APIRouter(
prefix="/examples",
Expand Down Expand Up @@ -36,7 +37,7 @@ def create(
create_request: CreateExample,
*,
session: Session = Depends(get_session),
user_id: UUID = Cookie(...),
user_id: UUID = Depends(UserToken),
) -> CreateExampleResponse:
"""Endpoint to create an example."""
if not validate_extractor_owner(session, create_request["extractor_id"], user_id):
Expand All @@ -59,7 +60,7 @@ def list(
limit: int = 10,
offset: int = 0,
session=Depends(get_session),
user_id: UUID = Cookie(...),
user_id: UUID = Depends(UserToken),
) -> List[Any]:
"""Endpoint to get all examples."""
if not validate_extractor_owner(session, extractor_id, user_id):
Expand All @@ -76,7 +77,10 @@ def list(

@router.delete("/{uuid}")
def delete(
uuid: UUID, *, session: Session = Depends(get_session), user_id: UUID = Cookie(...)
uuid: UUID,
*,
session: Session = Depends(get_session),
user_id: UUID = Depends(UserToken),
) -> None:
"""Endpoint to delete an example."""
extractor_id = session.query(Example).filter_by(uuid=str(uuid)).first().extractor_id
Expand Down
5 changes: 3 additions & 2 deletions backend/server/api/extract.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Literal, Optional
from uuid import UUID

from fastapi import APIRouter, Cookie, Depends, File, Form, HTTPException, UploadFile
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from sqlalchemy.orm import Session
from typing_extensions import Annotated

from db.models import Extractor, SharedExtractors, get_session
from extraction.parsing import parse_binary_input
from server.api.api_key import UserToken
from server.extraction_runnable import ExtractResponse, extract_entire_document
from server.models import DEFAULT_MODEL
from server.retrieval import extract_from_content
Expand All @@ -27,7 +28,7 @@ async def extract_using_existing_extractor(
file: Optional[UploadFile] = File(None),
model_name: Optional[str] = Form(DEFAULT_MODEL),
session: Session = Depends(get_session),
user_id: UUID = Cookie(...),
user_id: UUID = Depends(UserToken),
) -> ExtractResponse:
"""Endpoint that is used with an existing extractor.
Expand Down
19 changes: 13 additions & 6 deletions backend/server/api/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from typing import Any, Dict, List
from uuid import UUID, uuid4

from fastapi import APIRouter, Cookie, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field, validator
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from db.models import Extractor, SharedExtractors, get_session, validate_extractor_owner
from server.api.api_key import UserToken
from server.validators import validate_json_schema

router = APIRouter(
Expand Down Expand Up @@ -60,7 +61,7 @@ def share(
uuid: UUID,
*,
session: Session = Depends(get_session),
user_id: UUID = Cookie(...),
user_id: UUID = Depends(UserToken),
) -> ShareExtractorResponse:
"""Endpoint to share an extractor.
Expand Down Expand Up @@ -110,7 +111,7 @@ def create(
create_request: CreateExtractor,
*,
session: Session = Depends(get_session),
user_id: UUID = Cookie(...),
user_id: UUID = Depends(UserToken),
) -> CreateExtractorResponse:
"""Endpoint to create an extractor."""

Expand All @@ -128,7 +129,10 @@ def create(

@router.get("/{uuid}")
def get(
uuid: UUID, *, session: Session = Depends(get_session), user_id: UUID = Cookie(...)
uuid: UUID,
*,
session: Session = Depends(get_session),
user_id: UUID = Depends(UserToken),
) -> Dict[str, Any]:
"""Endpoint to get an extractor."""
extractor = (
Expand All @@ -151,7 +155,7 @@ def list(
limit: int = 10,
offset: int = 0,
session=Depends(get_session),
user_id: UUID = Cookie(...),
user_id: UUID = Depends(UserToken),
) -> List[Any]:
"""Endpoint to get all extractors."""
return (
Expand All @@ -165,7 +169,10 @@ def list(

@router.delete("/{uuid}")
def delete(
uuid: UUID, *, session: Session = Depends(get_session), user_id: UUID = Cookie(...)
uuid: UUID,
*,
session: Session = Depends(get_session),
user_id: UUID = Depends(UserToken),
) -> None:
"""Endpoint to delete an extractor."""
session.query(Extractor).filter_by(uuid=str(uuid), owner_id=user_id).delete()
Expand Down
6 changes: 3 additions & 3 deletions backend/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ def ready() -> str:


# Serve the frontend
ui_dir = str(ROOT / "ui")
UI_DIR = str(ROOT / "ui")

if os.path.exists(ui_dir):
app.mount("/", StaticFiles(directory=ui_dir, html=True), name="ui")
if os.path.exists(UI_DIR):
app.mount("/", StaticFiles(directory=UI_DIR, html=True), name="ui")
else:
logger.warning("No UI directory found, serving API only.")

Expand Down
52 changes: 25 additions & 27 deletions backend/tests/unit_tests/api/test_api_defining_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ async def test_extractors_api() -> None:
# First verify that the database is empty
async with get_async_client() as client:
user_id = str(uuid.uuid4())
cookies = {"user_id": user_id}
response = await client.get("/extractors", cookies=cookies)
headers = {"x-key": user_id}
response = await client.get("/extractors", headers=headers)
assert response.status_code == 200
assert response.json() == []

Expand All @@ -21,38 +21,36 @@ async def test_extractors_api() -> None:
"instruction": "Test Instruction",
}
response = await client.post(
"/extractors", json=create_request, cookies=cookies
"/extractors", json=create_request, headers=headers
)
assert response.status_code == 200

# Verify that the extractor was created
response = await client.get("/extractors", cookies=cookies)
response = await client.get("/extractors", headers=headers)
assert response.status_code == 200
get_response = response.json()
assert len(get_response) == 1

# Check cookies
bad_cookies = {"user_id": str(uuid.uuid4())}
bad_response = await client.get("/extractors", cookies=bad_cookies)
# Check headers
bad_headers = {"x-key": str(uuid.uuid4())}
bad_response = await client.get("/extractors", headers=bad_headers)
assert bad_response.status_code == 200
assert len(bad_response.json()) == 0

# Check we need cookie to delete
uuid_str = get_response[0]["uuid"]
_ = uuid.UUID(uuid_str) # assert valid uuid
bad_response = await client.delete(
f"/extractors/{uuid_str}", cookies=bad_cookies
)
await client.delete(f"/extractors/{uuid_str}", headers=bad_headers)
# Check extractor was not deleted
response = await client.get("/extractors", cookies=cookies)
response = await client.get("/extractors", headers=headers)
assert len(response.json()) == 1

# Verify that we can delete an extractor
_ = uuid.UUID(uuid_str) # assert valid uuid
response = await client.delete(f"/extractors/{uuid_str}", cookies=cookies)
response = await client.delete(f"/extractors/{uuid_str}", headers=headers)
assert response.status_code == 200

get_response = await client.get("/extractors", cookies=cookies)
get_response = await client.get("/extractors", headers=headers)
assert get_response.status_code == 200
assert get_response.json() == []

Expand All @@ -63,23 +61,23 @@ async def test_extractors_api() -> None:
"instruction": "Test Instruction",
}
response = await client.post(
"/extractors", json=create_request, cookies=cookies
"/extractors", json=create_request, headers=headers
)
assert response.status_code == 200

# Verify that the extractor was created
response = await client.get("/extractors", cookies=cookies)
response = await client.get("/extractors", headers=headers)
assert response.status_code == 200
assert len(response.json()) == 1

# Verify that we can delete an extractor
get_response = response.json()
uuid_str = get_response[0]["uuid"]
_ = uuid.UUID(uuid_str) # assert valid uuid
response = await client.delete(f"/extractors/{uuid_str}", cookies=cookies)
response = await client.delete(f"/extractors/{uuid_str}", headers=headers)
assert response.status_code == 200

get_response = await client.get("/extractors", cookies=cookies)
get_response = await client.get("/extractors", headers=headers)
assert get_response.status_code == 200
assert get_response.json() == []

Expand All @@ -92,11 +90,11 @@ async def test_extractors_api() -> None:
"instruction": "Test Instruction",
}
response = await client.post(
"/extractors", json=create_request, cookies=cookies
"/extractors", json=create_request, headers=headers
)
extractor_uuid = response.json()["uuid"]
assert response.status_code == 200
response = await client.get(f"/extractors/{extractor_uuid}", cookies=cookies)
response = await client.get(f"/extractors/{extractor_uuid}", headers=headers)
response_data = response.json()
assert extractor_uuid == response_data["uuid"]
assert "my extractor" == response_data["name"]
Expand All @@ -107,8 +105,8 @@ async def test_sharing_extractor() -> None:
"""Test sharing an extractor."""
async with get_async_client() as client:
user_id = str(uuid.uuid4())
cookies = {"user_id": user_id}
response = await client.get("/extractors", cookies=cookies)
headers = {"x-key": user_id}
response = await client.get("/extractors", headers=headers)
assert response.status_code == 200
assert response.json() == []
# Verify that we can create an extractor
Expand All @@ -119,28 +117,28 @@ async def test_sharing_extractor() -> None:
"instruction": "Test Instruction",
}
response = await client.post(
"/extractors", json=create_request, cookies=cookies
"/extractors", json=create_request, headers=headers
)
assert response.status_code == 200

uuid_str = response.json()["uuid"]

# Generate a share uuid
response = await client.post(f"/extractors/{uuid_str}/share", cookies=cookies)
response = await client.post(f"/extractors/{uuid_str}/share", headers=headers)
assert response.status_code == 200
assert "share_uuid" in response.json()
share_uuid = response.json()["share_uuid"]

# Test idempotency
response = await client.post(f"/extractors/{uuid_str}/share", cookies=cookies)
response = await client.post(f"/extractors/{uuid_str}/share", headers=headers)
assert response.status_code == 200
assert "share_uuid" in response.json()
assert response.json()["share_uuid"] == share_uuid

# Check cookies
bad_cookies = {"user_id": str(uuid.uuid4())}
# Check headers
bad_headers = {"x-key": str(uuid.uuid4())}
response = await client.post(
f"/extractors/{uuid_str}/share", cookies=bad_cookies
f"/extractors/{uuid_str}/share", headers=bad_headers
)
assert response.status_code == 404

Expand Down
26 changes: 13 additions & 13 deletions backend/tests/unit_tests/api/test_api_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@ async def test_examples_api() -> None:
async with get_async_client() as client:
# First create an extractor
user_id = str(uuid.uuid4())
cookies = {"user_id": user_id}
headers = {"x-key": user_id}
create_request = {
"description": "Test Description",
"name": "Test Name",
"schema": {"type": "object"},
"instruction": "Test Instruction",
}
response = await client.post(
"/extractors", json=create_request, cookies=cookies
"/extractors", json=create_request, headers=headers
)
assert response.status_code == 200
# Get the extractor id
extractor_id = response.json()["uuid"]

# Let's verify that there are no examples
response = await client.get(
"/examples?extractor_id=" + extractor_id, cookies=cookies
"/examples?extractor_id=" + extractor_id, headers=headers
)
assert response.status_code == 200
assert response.json() == []
Expand All @@ -48,20 +48,20 @@ async def test_examples_api() -> None:
}
],
}
response = await client.post("/examples", json=create_request, cookies=cookies)
response = await client.post("/examples", json=create_request, headers=headers)
assert response.status_code == 200
example_id = response.json()["uuid"]

# Check cookies
bad_cookies = {"user_id": str(uuid.uuid4())}
# Check headers
bad_headers = {"x-key": str(uuid.uuid4())}
response = await client.post(
"/examples", json=create_request, cookies=bad_cookies
"/examples", json=create_request, headers=bad_headers
)
assert response.status_code == 404

# Verify that the example was created
response = await client.get(
"/examples?extractor_id=" + extractor_id, cookies=cookies
"/examples?extractor_id=" + extractor_id, headers=headers
)
assert response.status_code == 200
assert len(response.json()) == 1
Expand All @@ -82,23 +82,23 @@ async def test_examples_api() -> None:
"uuid": example_id,
}

# Check cookies
# Check headers
response = await client.get(
"/examples?extractor_id=" + extractor_id, cookies=bad_cookies
"/examples?extractor_id=" + extractor_id, headers=bad_headers
)
assert response.status_code == 404

# Check we need cookie to delete
response = await client.delete(f"/examples/{example_id}", cookies=bad_cookies)
response = await client.delete(f"/examples/{example_id}", headers=bad_headers)
assert response.status_code == 404

# Verify that we can delete an example
response = await client.delete(f"/examples/{example_id}", cookies=cookies)
response = await client.delete(f"/examples/{example_id}", headers=headers)
assert response.status_code == 200

# Verify that the example was deleted
response = await client.get(
"/examples?extractor_id=" + extractor_id, cookies=cookies
"/examples?extractor_id=" + extractor_id, headers=headers
)
assert response.status_code == 200
assert response.json() == []
Loading

0 comments on commit f3e71cc

Please sign in to comment.