Skip to content

Commit

Permalink
Update API endpoints and dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
AndPuQing committed Jan 28, 2024
1 parent 96d70fe commit 4267c59
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 111 deletions.
192 changes: 170 additions & 22 deletions backend/app/app/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,157 @@
# Contents of JWT token
import logging
from datetime import datetime
from typing import Optional, Union

from pydantic import BaseModel, EmailStr, HttpUrl
from sqlmodel import JSON, AutoString, Column, Field, SQLModel
from typing import Any, Optional, Union

from pydantic import EmailStr, HttpUrl
from sqlalchemy.exc import IntegrityError, NoResultFound, OperationalError
from sqlalchemy.orm.exc import FlushError
from sqlmodel import JSON, AutoString, Column, Field, SQLModel, select


class ActiveRecordMixin:
__config__ = None

@property
def primary_key(self):
return self.__mapper__.primary_key_from_instance(self) # type: ignore

@classmethod
def first(cls, session):
statement = select(cls)
return session.exec(statement).first()

@classmethod
def one_by_id(cls, session, id: int):
obj = session.get(cls, id)
return obj

@classmethod
def first_by_field(cls, session, field: str, value: Any):
return cls.first_by_fields(session, {field: value})

@classmethod
def one_by_field(cls, session, field: str, value: Any):
return cls.one_by_fields(session, {field: value})

@classmethod
def first_by_fields(cls, session, fields: dict):
statement = select(cls)
for key, value in fields.items():
statement = statement.where(getattr(cls, key) == value)
try:
return session.exec(statement).first()
except NoResultFound:
logging.error(f"{cls}: first_by_fields failed, NoResultFound")
return None

@classmethod
def one_by_fields(cls, session, fields: dict):
statement = select(cls)
for key, value in fields.items():
statement = statement.where(getattr(cls, key) == value)
try:
return session.exec(statement).one()
except NoResultFound:
logging.error(f"{cls}: one_by_fields failed, NoResultFound")
return None

@classmethod
def all_by_field(cls, session, field: str, value: Any):
statement = select(cls).where(getattr(cls, field) == value)
return session.exec(statement).all()

@classmethod
def all_by_fields(cls, session, fields: dict):
statement = select(cls)
for key, value in fields.items():
statement = statement.where(getattr(cls, key) == value)
return session.exec(statement).all()

@classmethod
def convert_without_saving(
cls,
source: Union[dict, SQLModel],
update: Optional[dict] = None,
) -> SQLModel:
if isinstance(source, SQLModel):
obj = cls.from_orm(source, update=update) # type: ignore
elif isinstance(source, dict):
obj = cls.parse_obj(source, update=update) # type: ignore
return obj

@classmethod
def create(
cls,
session,
source: Union[dict, SQLModel],
update: Optional[dict] = None,
) -> Optional[SQLModel]:
obj = cls.convert_without_saving(source, update)
if obj is None:
return None
if obj.save(session):
return obj
return None

@classmethod
def create_or_update(
cls,
session,
source: Union[dict, SQLModel],
update: Optional[dict] = None,
) -> Optional[SQLModel]:
obj = cls.convert_without_saving(source, update)
if obj is None:
return None
pk = cls.__mapper__.primary_key_from_instance(obj) # type: ignore
if pk[0] is not None:
existing = session.get(cls, pk)
if existing is None:
return None # Error
else:
existing.update(session, obj) # Update
return existing
else:
return cls.create(session, obj) # Create

@classmethod
def count(cls, session) -> int:
return len(cls.all(session))

def refresh(self, session):
session.refresh(self)

def save(self, session) -> bool:
session.add(self)
try:
session.commit()
session.refresh(self)
return True
except (IntegrityError, OperationalError, FlushError) as e:
logging.error(e)
session.rollback()
return False

def update(self, session, source: Union[dict, SQLModel]):
if isinstance(source, SQLModel):
source = source.model_dump(exclude_unset=True)

for key, value in source.items():
setattr(self, key, value)
self.save(session)

def delete(self, session):
session.delete(self)
session.commit()

@classmethod
def all(cls, session):
return session.exec(select(cls)).all()

@classmethod
def delete_all(cls, session):
for obj in cls.all(session):
obj.delete(session)


# Shared properties
Expand Down Expand Up @@ -35,15 +183,15 @@ class UserUpdate(UserBase):
password: Union[str, None] = None


class UserUpdateMe(BaseModel):
class UserUpdateMe(SQLModel):
password: Union[str, None] = None
full_name: Union[str, None] = None
email: Union[EmailStr, None] = None
subscription: Union[list[str], None] = None


# Database model, database table inferred from class name
class User(UserBase, table=True):
class User(ActiveRecordMixin, UserBase, table=True):
id: Union[int, None] = Field(default=None, primary_key=True)
hashed_password: str

Expand All @@ -61,11 +209,15 @@ class ItemBase(SQLModel):
default=None,
sa_column=Column(JSON),
)
authors: Union[list[str], None] = Field(
default=None,
sa_column=Column(JSON),
)
url: Optional[str]


# Properties to receive on item creation
class ItemCreate(ItemBase):
title: str
url: Optional[HttpUrl] = None


Expand All @@ -79,23 +231,19 @@ class ItemUpdate(ItemBase):


# Database model, database table inferred from class name
class Item(ItemBase, table=True):
class Item(ActiveRecordMixin, ItemBase, table=True):
id: Union[int, None] = Field(default=None, primary_key=True)
is_hidden: bool = False
url: Optional[str]
authors: Union[list[str], None] = Field(
default=None,
sa_column=Column(JSON),
)

from_source: str = Field(nullable=False)
category: Union[list[str], None] = Field(
default=None,
sa_column=Column(JSON),
)
last_updated: datetime = Field(
default_factory=datetime.utcnow,
nullable=False,
)
category: Union[list[str], None] = Field(
default=None,
sa_column=Column(JSON),
)


# Properties to return via API, id is always required
Expand All @@ -112,21 +260,21 @@ class CrawledItem(SQLModel, table=True):
)


class TokenPayload(BaseModel):
class TokenPayload(SQLModel):
sub: Union[int, None] = None


# Generic message
class Message(BaseModel):
class Message(SQLModel):
message: str


class NewPassword(BaseModel):
class NewPassword(SQLModel):
token: str
new_password: str


# JSON payload containing access token
class Token(BaseModel):
class Token(SQLModel):
access_token: str
token_type: str = "bearer"
42 changes: 0 additions & 42 deletions backend/app/app/schemas/item.py

This file was deleted.

39 changes: 0 additions & 39 deletions backend/app/app/schemas/user.py

This file was deleted.

3 changes: 2 additions & 1 deletion backend/app/app/web/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import jwt
from jose.exceptions import JWTError
from pydantic import ValidationError
from sqlmodel import Session

Expand Down Expand Up @@ -34,7 +35,7 @@ def get_current_user(session: SessionDep, token: TokenDep) -> User:
algorithms=[security.ALGORITHM],
)
token_data = TokenPayload(**payload)
except (jwt.JWTError, ValidationError):
except (JWTError, ValidationError):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
Expand Down
17 changes: 11 additions & 6 deletions backend/app/app/web/api/endpoints/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@
@router.get("/", response_model=list[ItemOut])
def read_items(
session: SessionDep,
current_user: CurrentUser,
skip: int = 0,
limit: int = 100,
) -> list[ItemOut]:
) -> Any:
"""
Retrieve items.
"""
if current_user.is_superuser:
raise HTTPException(status_code=400, detail="Not enough permissions")
statement = select(Item).offset(skip).limit(limit)
return session.exec(statement).all()

Expand Down Expand Up @@ -66,7 +63,11 @@ def create_items(*, session: SessionDep, items_in: list[ItemCreate]) -> Any:

@router.put("/{id}", response_model=ItemOut)
def update_item(
*, session: SessionDep, current_user: CurrentUser, id: int, item_in: ItemUpdate
*,
session: SessionDep,
current_user: CurrentUser,
id: int,
item_in: ItemUpdate
) -> Any:
"""
Update an item.
Expand All @@ -81,7 +82,11 @@ def update_item(


@router.delete("/{id}", response_model=ItemOut)
def delete_item(session: SessionDep, current_user: CurrentUser, id: int) -> ItemOut:
def delete_item(
session: SessionDep,
current_user: CurrentUser,
id: int,
) -> ItemOut:
"""
Delete an item.
"""
Expand Down
Loading

0 comments on commit 4267c59

Please sign in to comment.